Commit 1cd1999e authored by Amit Patankar

Setting the CIFAR parameters for the appropriate learning rate and num gpus.

parent 3570932e
@@ -77,10 +77,14 @@ class TimeHistory(tf.keras.callbacks.Callback):
           (batch, last_n_batches, examples_per_second))
 
-LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
-    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
+# LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
+#     (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
+# ]
+LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
+    (0.1, 91), (0.01, 136), (0.001, 182)
 ]
-BASE_LEARNING_RATE = 3.2 #0.128
+NUM_GPUS = flags_core.get_num_gpus(flags.FLAGS)
+BASE_LEARNING_RATE = 0.1 * NUM_GPUS
 
 def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch):
   """Handles linear scaling rule, gradual warmup, and LR decay.
@@ -99,11 +103,19 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch):
   Returns:
     Adjusted learning rate.
   """
+  # epoch = current_epoch + float(current_batch) / batches_per_epoch
+  # warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
+  # if epoch < warmup_end_epoch:
+  #   # Learning rate increases linearly per step.
+  #   return BASE_LEARNING_RATE * warmup_lr_multiplier * epoch / warmup_end_epoch
+  # for mult, start_epoch in LR_SCHEDULE:
+  #   if epoch >= start_epoch:
+  #     learning_rate = BASE_LEARNING_RATE * mult
+  #   else:
+  #     break
+  # return learning_rate
   epoch = current_epoch + float(current_batch) / batches_per_epoch
   warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
   if epoch < warmup_end_epoch:
     # Learning rate increases linearly per step.
     return BASE_LEARNING_RATE * warmup_lr_multiplier * epoch / warmup_end_epoch
   for mult, start_epoch in LR_SCHEDULE:
     if epoch >= start_epoch:
       learning_rate = BASE_LEARNING_RATE * mult
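For readability, a runnable sketch of learning_rate_schedule as it reads after this change. The hunk above is truncated, so the trailing else/break and final return are reconstructed from the commented-out copy, and the module-level constants are stubbed (single GPU assumed) so the snippet stands alone.

# Runnable sketch of learning_rate_schedule after this commit. The else/break
# and final return are reconstructed from the commented-out copy above; the
# constants are stand-ins so the snippet runs without the rest of the module.
BASE_LEARNING_RATE = 0.1 * 1  # 0.1 per GPU, single GPU assumed here
LR_SCHEDULE = [(0.1, 91), (0.01, 136), (0.001, 182)]

def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch):
  """Handles linear scaling rule, gradual warmup, and LR decay."""
  epoch = current_epoch + float(current_batch) / batches_per_epoch
  warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
  if epoch < warmup_end_epoch:
    # Learning rate increases linearly per step toward 0.1 * BASE_LEARNING_RATE.
    return BASE_LEARNING_RATE * warmup_lr_multiplier * epoch / warmup_end_epoch
  # Past warmup, epoch >= LR_SCHEDULE[0][1], so learning_rate is always set below.
  for mult, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      learning_rate = BASE_LEARNING_RATE * mult
    else:
      break
  return learning_rate

# With the new schedule, the warmup branch runs until epoch 91:
print(learning_rate_schedule(45, 0, 100))   # ~0.0049 (mid-warmup)
print(learning_rate_schedule(100, 0, 100))  # ~0.01   (0.1 multiplier)
print(learning_rate_schedule(150, 0, 100))  # ~0.001  (0.01 multiplier)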