Commit e932712b authored by Priya Gupta

Change LR schedule to adjust according to batch size

parent 746a927c
@@ -83,10 +83,10 @@ class TimeHistory(tf.keras.callbacks.Callback):
 LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
     (0.1, 91), (0.01, 136), (0.001, 182)
 ]
-NUM_GPUS = flags_core.get_num_gpus(flags.FLAGS)
-BASE_LEARNING_RATE = 0.1 * NUM_GPUS
+BASE_LEARNING_RATE = 0.1

-def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch):
+def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
   """Handles linear scaling rule, gradual warmup, and LR decay.

   The learning rate starts at 0, then it increases linearly per step.
@@ -115,11 +115,11 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch):
   #     break
   #   return learning_rate
-  epoch = current_epoch + float(current_batch) / batches_per_epoch
-  learning_rate = BASE_LEARNING_RATE
+  initial_learning_rate = BASE_LEARNING_RATE * batch_size / 128
+  learning_rate = initial_learning_rate
   for mult, start_epoch in LR_SCHEDULE:
-    if epoch >= start_epoch:
-      learning_rate = BASE_LEARNING_RATE * mult
+    if current_epoch >= start_epoch:
+      learning_rate = initial_learning_rate * mult
     else:
       break
   return learning_rate
@@ -140,6 +140,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     super(LearningRateBatchScheduler, self).__init__()
     self.schedule = schedule
     self.batches_per_epoch = num_images / batch_size
+    self.batch_size = batch_size
     self.epochs = -1
     self.prev_lr = -1
@@ -149,7 +150,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     self.epochs += 1

   def on_batch_begin(self, batch, logs=None):
-    lr = self.schedule(self.epochs, batch, self.batches_per_epoch)
+    lr = self.schedule(self.epochs, batch, self.batches_per_epoch, self.batch_size)
     if not isinstance(lr, (float, np.float32, np.float64)):
       raise ValueError('The output of the "schedule" function should be float.')
     if lr != self.prev_lr:
@@ -273,7 +274,7 @@ def run_cifar_with_keras(flags_obj):
   tesorboard_callback = tf.keras.callbacks.TensorBoard(
       log_dir=flags_obj.model_dir)
-      # update_freq="batch")  # Add this if want per batch logging.
+      #update_freq="batch")  # Add this if want per batch logging.

   lr_callback = LearningRateBatchScheduler(
       learning_rate_schedule,
...
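For reference, a minimal standalone sketch of the CIFAR-10 schedule after this change. The function body mirrors the hunks above; the batch sizes and batches_per_epoch values in the example calls are arbitrary illustrations, not values taken from the commit:

LR_SCHEDULE = [(0.1, 91), (0.01, 136), (0.001, 182)]  # (multiplier, epoch to start)
BASE_LEARNING_RATE = 0.1

def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
  # Linear scaling rule: the peak LR now grows with batch size, with 128 as the
  # reference batch size (0.1 at batch 128, 0.8 at batch 1024).
  # current_batch and batches_per_epoch stay in the signature but are no longer
  # used in this version of the CIFAR-10 schedule.
  initial_learning_rate = BASE_LEARNING_RATE * batch_size / 128
  learning_rate = initial_learning_rate
  for mult, start_epoch in LR_SCHEDULE:
    if current_epoch >= start_epoch:
      learning_rate = initial_learning_rate * mult
    else:
      break
  return learning_rate

print(learning_rate_schedule(0, 0, 390, 128))    # 0.1
print(learning_rate_schedule(0, 0, 48, 1024))    # 0.8
print(learning_rate_schedule(91, 0, 390, 128))   # ~0.01, after the first decay step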
@@ -81,9 +81,9 @@ class TimeHistory(tf.keras.callbacks.Callback):
 LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
     (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
 ]
-BASE_LEARNING_RATE = 0.4 #0.128
+BASE_LEARNING_RATE = 0.128

-def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch):
+def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
   """Handles linear scaling rule, gradual warmup, and LR decay.

   The learning rate starts at 0, then it increases linearly per step.
@@ -100,14 +100,15 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch):
   Returns:
     Adjusted learning rate.
   """
+  initial_learning_rate = BASE_LEARNING_RATE * batch_size / 256
   epoch = current_epoch + float(current_batch) / batches_per_epoch
   warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
   if epoch < warmup_end_epoch:
     # Learning rate increases linearly per step.
-    return BASE_LEARNING_RATE * warmup_lr_multiplier * epoch / warmup_end_epoch
+    return initial_learning_rate * warmup_lr_multiplier * epoch / warmup_end_epoch
   for mult, start_epoch in LR_SCHEDULE:
     if epoch >= start_epoch:
-      learning_rate = BASE_LEARNING_RATE * mult
+      learning_rate = initial_learning_rate * mult
     else:
       break
   return learning_rate
@@ -128,6 +129,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     super(LearningRateBatchScheduler, self).__init__()
     self.schedule = schedule
     self.batches_per_epoch = num_images / batch_size
+    self.batch_size = batch_size
     self.epochs = -1
     self.prev_lr = -1
@@ -137,7 +139,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     self.epochs += 1

   def on_batch_begin(self, batch, logs=None):
-    lr = self.schedule(self.epochs, batch, self.batches_per_epoch)
+    lr = self.schedule(self.epochs, batch, self.batches_per_epoch, self.batch_size)
     if not isinstance(lr, (float, np.float32, np.float64)):
       raise ValueError('The output of the "schedule" function should be float.')
     if lr != self.prev_lr:
@@ -187,6 +189,15 @@ def run_imagenet_with_keras(flags_obj):
   Raises:
     ValueError: If fp16 is passed as it is not currently supported.
   """
+  # Set all random seeds to fixed values.
+  import random
+  import numpy as np
+  seed = 87654321
+  random.seed(seed)
+  np.random.seed(seed)
+  tf.random.set_random_seed(seed)
+
   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'fp16':
     raise ValueError('dtype fp16 is not supported in Keras. Use the default '
@@ -239,10 +250,11 @@ def run_imagenet_with_keras(flags_obj):
   # opt = tf.train.GradientDescentOptimizer(learning_rate=0.0001)
   # I am setting an initial LR of 0.001 since this will be reset
   # at the beginning of the training loop.
-  opt = gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
+  # opt = gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
   # TF Optimizer:
-  # opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
+  learning_rate = BASE_LEARNING_RATE * flags_obj.batch_size / 256
+  opt = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)

   strategy = distribution_utils.get_distribution_strategy(
       num_gpus=flags_obj.num_gpus)
@@ -264,8 +276,8 @@ def run_imagenet_with_keras(flags_obj):
   time_callback = TimeHistory(flags_obj.batch_size)
   tesorboard_callback = tf.keras.callbacks.TensorBoard(
-      log_dir=flags_obj.model_dir)
-      # update_freq="batch")  # Add this if want per batch logging.
+      log_dir=flags_obj.model_dir,
+      update_freq="batch")  # Add this if want per batch logging.

   lr_callback = LearningRateBatchScheduler(
       learning_rate_schedule,
@@ -280,7 +292,7 @@ def run_imagenet_with_keras(flags_obj):
       steps_per_epoch=steps_per_epoch,
       callbacks=[
           time_callback,
-          lr_callback,
+          #lr_callback,
           tesorboard_callback
       ],
       verbose=1)
...
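Similarly, a minimal standalone sketch of the ImageNet schedule after this change, including the gradual warmup. The function body mirrors the hunks above; the batch size of 1024 and 1251 batches per epoch in the example calls are arbitrary illustrations, not values taken from the commit:

LR_SCHEDULE = [(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)]  # (multiplier, epoch to start)
BASE_LEARNING_RATE = 0.128

def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
  # Linear scaling rule with 256 as the reference batch size: batch 256 keeps
  # the 0.128 peak LR, batch 1024 scales it to 0.512.
  initial_learning_rate = BASE_LEARNING_RATE * batch_size / 256
  epoch = current_epoch + float(current_batch) / batches_per_epoch
  warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
  if epoch < warmup_end_epoch:
    # Gradual warmup: ramp linearly from 0 to the peak LR over the first 5 epochs.
    return initial_learning_rate * warmup_lr_multiplier * epoch / warmup_end_epoch
  for mult, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      learning_rate = initial_learning_rate * mult
    else:
      break
  return learning_rate

print(learning_rate_schedule(2, 625, 1251, 1024))   # ~0.256, halfway through warmup
print(learning_rate_schedule(10, 0, 1251, 1024))    # 0.512, peak LR after warmup
print(learning_rate_schedule(30, 0, 1251, 1024))    # ~0.0512, after the first decay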