Commit 863a44ce authored by Toby Boyd

fixed lint issues.

parent 61ec2907
@@ -23,8 +23,6 @@ from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 
 from official.resnet import imagenet_main
-from official.resnet import imagenet_preprocessing
-from official.resnet import resnet_run_loop
 from official.resnet.keras import keras_common
 from official.resnet.keras import resnet_model
 from official.utils.flags import core as flags_core
@@ -37,28 +35,33 @@ LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
 ]
 
-def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
+def learning_rate_schedule(current_epoch,
+                           current_batch,
+                           batches_per_epoch,
+                           batch_size):
   """Handles linear scaling rule, gradual warmup, and LR decay.
 
-  Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the provided scaling
-  factor.
+  Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
+  provided scaling factor.
 
   Args:
     current_epoch: integer, current epoch indexed from 0.
     current_batch: integer, current batch in the current epoch, indexed from 0.
     batches_per_epoch: integer, number of steps in an epoch.
     batch_size: integer, total batch size.
 
   Returns:
     Adjusted learning rate.
   """
-  initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 256
+  initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256
   epoch = current_epoch + float(current_batch) / batches_per_epoch
   warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
   if epoch < warmup_end_epoch:
     # Learning rate increases linearly per step.
-    return initial_learning_rate * warmup_lr_multiplier * epoch / warmup_end_epoch
+    return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
   for mult, start_epoch in LR_SCHEDULE:
     if epoch >= start_epoch:
-      learning_rate = initial_learning_rate * mult
+      learning_rate = initial_lr * mult
     else:
       break
   return learning_rate
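
For a concrete sense of what this schedule computes, here is a minimal standalone sketch of the warmup-plus-decay logic above. The 0.1 base rate and the (multiplier, start-epoch) entries are assumptions for illustration only; the actual keras_common.BASE_LEARNING_RATE and the full LR_SCHEDULE list are truncated out of this diff.

# Standalone sketch of the schedule above. BASE_LEARNING_RATE and the
# LR_SCHEDULE entries are assumed values, not taken from this diff.
BASE_LEARNING_RATE = 0.1
LR_SCHEDULE = [(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)]

def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch,
                           batch_size):
  # Linear scaling rule: scale the base rate by (global batch size / 256).
  initial_lr = BASE_LEARNING_RATE * batch_size / 256
  epoch = current_epoch + float(current_batch) / batches_per_epoch
  warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
  if epoch < warmup_end_epoch:
    # Gradual warmup: ramp linearly from 0 to the scaled rate, per step.
    return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
  learning_rate = initial_lr
  for mult, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      learning_rate = initial_lr * mult
    else:
      break
  return learning_rate

# With a global batch of 1024, initial_lr = 0.1 * 1024 / 256 = 0.4:
print(learning_rate_schedule(2, 50, 100, 1024))  # warmup: 0.4 * 2.5 / 5 = 0.2
print(learning_rate_schedule(10, 0, 100, 1024))  # after warmup: 0.4
print(learning_rate_schedule(35, 0, 100, 1024))  # first decay: 0.4 * 0.1 = 0.04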
@@ -106,15 +109,13 @@ def run(flags_obj):
   else:
     input_fn = imagenet_main.input_fn
 
-  train_input_dataset = input_fn(
-      is_training=True,
+  train_input_dataset = input_fn(is_training=True,
       data_dir=flags_obj.data_dir,
       batch_size=per_device_batch_size,
       num_epochs=flags_obj.train_epochs,
       parse_record_fn=parse_record_keras)
 
-  eval_input_dataset = input_fn(
-      is_training=False,
+  eval_input_dataset = input_fn(is_training=False,
       data_dir=flags_obj.data_dir,
       batch_size=per_device_batch_size,
       num_epochs=flags_obj.train_epochs,
@@ -149,7 +150,7 @@ def run(flags_obj):
   num_eval_steps = None
   validation_data = None
 
-  history = model.fit(train_input_dataset,
+  model.fit(train_input_dataset,
             epochs=train_epochs,
             steps_per_epoch=train_steps,
             callbacks=[
@@ -162,7 +163,7 @@ def run(flags_obj):
             verbose=1)
 
   if not flags_obj.skip_eval:
-    eval_output = model.evaluate(eval_input_dataset,
+    model.evaluate(eval_input_dataset,
                   steps=num_eval_steps,
                   verbose=1)
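
The last two hunks drop assignments whose results were never read, which is what the lint warning flagged. For context, a minimal sketch of what those tf.keras return values carry; the toy model and data here are purely illustrative, not from this commit:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse', metrics=['mae'])
x = tf.random.normal((32, 4))
y = tf.random.normal((32, 1))

# fit() returns a History object; its .history dict maps each metric name
# to a list of per-epoch values. evaluate() returns the loss followed by
# the compiled metrics. Both are safe to discard when unused.
history = model.fit(x, y, epochs=2, verbose=0)
loss, mae = model.evaluate(x, y, verbose=0)
print(history.history['loss'], loss, mae)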