Commit e932712b authored by Priya Gupta

Change LR schedule to adjust according to batch size

parent 746a927c
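For context: both files below apply the linear scaling rule, where the base learning rate is multiplied by the global batch size divided by a reference batch size (128 for CIFAR-10, 256 for ImageNet). A minimal sketch of the rule in isolation (the helper name and example batch sizes are illustrative, not part of the change):

def scaled_learning_rate(base_lr, batch_size, reference_batch_size):
  # Linear scaling rule: grow the learning rate in proportion to the global batch size.
  return base_lr * batch_size / reference_batch_size

# e.g. CIFAR-10: scaled_learning_rate(0.1, 256, 128) == 0.2
# e.g. ImageNet: scaled_learning_rate(0.128, 1024, 256) == 0.512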
@@ -83,10 +83,10 @@ class TimeHistory(tf.keras.callbacks.Callback):
 LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
     (0.1, 91), (0.01, 136), (0.001, 182)
 ]
-NUM_GPUS = flags_core.get_num_gpus(flags.FLAGS)
-BASE_LEARNING_RATE = 0.1 * NUM_GPUS
-def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch):
+BASE_LEARNING_RATE = 0.1
+def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
   """Handles linear scaling rule, gradual warmup, and LR decay.
   The learning rate starts at 0, then it increases linearly per step.
@@ -115,11 +115,11 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch):
   # break
   # return learning_rate
   epoch = current_epoch + float(current_batch) / batches_per_epoch
-  learning_rate = BASE_LEARNING_RATE
+  initial_learning_rate = BASE_LEARNING_RATE * batch_size / 128
+  learning_rate = initial_learning_rate
   for mult, start_epoch in LR_SCHEDULE:
-    if epoch >= start_epoch:
-      learning_rate = BASE_LEARNING_RATE * mult
+    if current_epoch >= start_epoch:
+      learning_rate = initial_learning_rate * mult
     else:
       break
   return learning_rate
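A rough trace of the CIFAR schedule above (a sketch assuming batch_size=128, so initial_learning_rate equals BASE_LEARNING_RATE = 0.1; larger batch sizes scale every value by batch_size / 128):

# epochs   0-90  -> 0.1
# epochs  91-135 -> 0.1 * 0.1   = 0.01
# epochs 136-181 -> 0.1 * 0.01  = 0.001
# epochs 182+    -> 0.1 * 0.001 = 0.0001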
@@ -140,6 +140,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     super(LearningRateBatchScheduler, self).__init__()
     self.schedule = schedule
     self.batches_per_epoch = num_images / batch_size
+    self.batch_size = batch_size
     self.epochs = -1
     self.prev_lr = -1
@@ -149,7 +150,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     self.epochs += 1
   def on_batch_begin(self, batch, logs=None):
-    lr = self.schedule(self.epochs, batch, self.batches_per_epoch)
+    lr = self.schedule(self.epochs, batch, self.batches_per_epoch, self.batch_size)
     if not isinstance(lr, (float, np.float32, np.float64)):
       raise ValueError('The output of the "schedule" function should be float.')
     if lr != self.prev_lr:
@@ -273,7 +274,7 @@ def run_cifar_with_keras(flags_obj):
   tesorboard_callback = tf.keras.callbacks.TensorBoard(
       log_dir=flags_obj.model_dir)
-      # update_freq="batch") # Add this if want per batch logging.
+      #update_freq="batch") # Add this if want per batch logging.
   lr_callback = LearningRateBatchScheduler(
       learning_rate_schedule,
......
@@ -81,9 +81,9 @@ class TimeHistory(tf.keras.callbacks.Callback):
 LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
     (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
 ]
-BASE_LEARNING_RATE = 0.4 #0.128
+BASE_LEARNING_RATE = 0.128
-def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch):
+def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
   """Handles linear scaling rule, gradual warmup, and LR decay.
   The learning rate starts at 0, then it increases linearly per step.
@@ -100,14 +100,15 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch):
   Returns:
     Adjusted learning rate.
   """
+  initial_learning_rate = BASE_LEARNING_RATE * batch_size / 256
   epoch = current_epoch + float(current_batch) / batches_per_epoch
   warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
   if epoch < warmup_end_epoch:
     # Learning rate increases linearly per step.
-    return BASE_LEARNING_RATE * warmup_lr_multiplier * epoch / warmup_end_epoch
+    return initial_learning_rate * warmup_lr_multiplier * epoch / warmup_end_epoch
   for mult, start_epoch in LR_SCHEDULE:
     if epoch >= start_epoch:
-      learning_rate = BASE_LEARNING_RATE * mult
+      learning_rate = initial_learning_rate * mult
     else:
       break
   return learning_rate
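A rough trace of the ImageNet schedule above (a sketch assuming batch_size=256, so initial_learning_rate equals BASE_LEARNING_RATE = 0.128; other batch sizes scale every value by batch_size / 256):

# epoch < 5         -> 0.128 * epoch / 5   (linear warmup from 0)
# 5  <= epoch < 30  -> 0.128
# 30 <= epoch < 60  -> 0.0128
# 60 <= epoch < 80  -> 0.00128
# epoch >= 80       -> 0.000128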
@@ -128,6 +129,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     super(LearningRateBatchScheduler, self).__init__()
     self.schedule = schedule
     self.batches_per_epoch = num_images / batch_size
+    self.batch_size = batch_size
     self.epochs = -1
     self.prev_lr = -1
@@ -137,7 +139,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     self.epochs += 1
   def on_batch_begin(self, batch, logs=None):
-    lr = self.schedule(self.epochs, batch, self.batches_per_epoch)
+    lr = self.schedule(self.epochs, batch, self.batches_per_epoch, self.batch_size)
     if not isinstance(lr, (float, np.float32, np.float64)):
       raise ValueError('The output of the "schedule" function should be float.')
     if lr != self.prev_lr:
@@ -187,6 +189,15 @@ def run_imagenet_with_keras(flags_obj):
   Raises:
     ValueError: If fp16 is passed as it is not currently supported.
   """
+  # Set all random seeds to fixed values.
+  import random
+  import numpy as np
+  seed = 87654321
+  random.seed(seed)
+  np.random.seed(seed)
+  tf.random.set_random_seed(seed)
   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'fp16':
     raise ValueError('dtype fp16 is not supported in Keras. Use the default '
@@ -239,10 +250,11 @@ def run_imagenet_with_keras(flags_obj):
   # opt = tf.train.GradientDescentOptimizer(learning_rate=0.0001)
   # I am setting an initial LR of 0.001 since this will be reset
   # at the beginning of the training loop.
-  opt = gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
+  # opt = gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
   # TF Optimizer:
-  # opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
+  learning_rate = BASE_LEARNING_RATE * flags_obj.batch_size / 256
+  opt = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
   strategy = distribution_utils.get_distribution_strategy(
       num_gpus=flags_obj.num_gpus)
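The optimizer's starting learning rate is now derived from the batch-size flag; a quick check of the arithmetic (batch sizes here are just examples):

# batch_size  256 -> learning_rate = 0.128 *  256 / 256 = 0.128
# batch_size  512 -> learning_rate = 0.128 *  512 / 256 = 0.256
# batch_size 1024 -> learning_rate = 0.128 * 1024 / 256 = 0.512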
@@ -264,8 +276,8 @@ def run_imagenet_with_keras(flags_obj):
   time_callback = TimeHistory(flags_obj.batch_size)
   tesorboard_callback = tf.keras.callbacks.TensorBoard(
-      log_dir=flags_obj.model_dir)
-      # update_freq="batch") # Add this if want per batch logging.
+      log_dir=flags_obj.model_dir,
+      update_freq="batch") # Add this if want per batch logging.
   lr_callback = LearningRateBatchScheduler(
       learning_rate_schedule,
@@ -280,7 +292,7 @@ def run_imagenet_with_keras(flags_obj):
       steps_per_epoch=steps_per_epoch,
       callbacks=[
           time_callback,
-          lr_callback,
+          #lr_callback,
           tesorboard_callback
       ],
       verbose=1)
......
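For reference, a minimal sketch of how the pieces fit together once the scheduler callback is enabled, mirroring the CIFAR wiring above (in this commit the callback stays active for CIFAR but is commented out of the ImageNet fit call). model, train_dataset, train_epochs, num_images and batch_size are placeholder names, and the keyword order of LearningRateBatchScheduler's constructor is assumed from the attributes it sets:

lr_callback = LearningRateBatchScheduler(
    learning_rate_schedule,  # now called as schedule(epoch, batch, batches_per_epoch, batch_size)
    batch_size=batch_size,
    num_images=num_images)
model.fit(train_dataset,
          epochs=train_epochs,
          steps_per_epoch=num_images // batch_size,
          callbacks=[time_callback, lr_callback, tesorboard_callback],
          verbose=1)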