Commit 863a44ce authored by Toby Boyd

fixed lint issues.

parent 61ec2907
@@ -23,8 +23,6 @@ from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 from official.resnet import imagenet_main
-from official.resnet import imagenet_preprocessing
-from official.resnet import resnet_run_loop
 from official.resnet.keras import keras_common
 from official.resnet.keras import resnet_model
 from official.utils.flags import core as flags_core
@@ -37,28 +35,33 @@ LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
 ]
 
-def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
+def learning_rate_schedule(current_epoch,
+                           current_batch,
+                           batches_per_epoch,
+                           batch_size):
   """Handles linear scaling rule, gradual warmup, and LR decay.
 
-  Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the provided scaling
-  factor.
+  Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
+  provided scaling factor.
 
   Args:
     current_epoch: integer, current epoch indexed from 0.
     current_batch: integer, current batch in the current epoch, indexed from 0.
+    batches_per_epoch: integer, number of steps in an epoch.
+    batch_size: integer, total batch size.
 
   Returns:
     Adjusted learning rate.
   """
-  initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 256
+  initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256
   epoch = current_epoch + float(current_batch) / batches_per_epoch
   warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
   if epoch < warmup_end_epoch:
     # Learning rate increases linearly per step.
-    return initial_learning_rate * warmup_lr_multiplier * epoch / warmup_end_epoch
+    return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
   for mult, start_epoch in LR_SCHEDULE:
     if epoch >= start_epoch:
-      learning_rate = initial_learning_rate * mult
+      learning_rate = initial_lr * mult
     else:
       break
   return learning_rate
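
For context on the warmup-then-decay logic above, here is a self-contained sketch of the same schedule. BASE_LEARNING_RATE and the LR_SCHEDULE entries below are illustrative assumptions; the real constants live in keras_common and in the collapsed part of this file.

# Standalone sketch of the schedule above. BASE_LEARNING_RATE and the
# LR_SCHEDULE entries are assumed values for illustration only.
BASE_LEARNING_RATE = 0.1
LR_SCHEDULE = [(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)]  # (multiplier, epoch)


def learning_rate_schedule(current_epoch, current_batch,
                           batches_per_epoch, batch_size):
  # Linear scaling rule: scale the base LR with the global batch size.
  initial_lr = BASE_LEARNING_RATE * batch_size / 256
  epoch = current_epoch + float(current_batch) / batches_per_epoch
  warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
  if epoch < warmup_end_epoch:
    # Gradual warmup: ramp linearly from 0 toward the scaled LR.
    return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
  for mult, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      learning_rate = initial_lr * mult
    else:
      break
  return learning_rate


# With batch_size=1024: initial_lr = 0.1 * 1024 / 256 = 0.4.
print(learning_rate_schedule(2, 0, 100, 1024))   # mid-warmup, ~0.16
print(learning_rate_schedule(10, 0, 100, 1024))  # post-warmup, 0.4
print(learning_rate_schedule(35, 0, 100, 1024))  # first decay step, ~0.04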
@@ -106,15 +109,13 @@ def run(flags_obj):
   else:
     input_fn = imagenet_main.input_fn
 
-  train_input_dataset = input_fn(
-      is_training=True,
+  train_input_dataset = input_fn(is_training=True,
       data_dir=flags_obj.data_dir,
       batch_size=per_device_batch_size,
       num_epochs=flags_obj.train_epochs,
       parse_record_fn=parse_record_keras)
 
-  eval_input_dataset = input_fn(
-      is_training=False,
+  eval_input_dataset = input_fn(is_training=False,
       data_dir=flags_obj.data_dir,
       batch_size=per_device_batch_size,
       num_epochs=flags_obj.train_epochs,
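
Both calls rely on the input_fn signature shared by imagenet_main and the synthetic-data path: (is_training, data_dir, batch_size, num_epochs, parse_record_fn). A minimal tf.data stand-in with the same shape is sketched below; the file patterns and the parse_record_fn contract are assumptions, not the repository's actual pipeline.

import os
import tensorflow as tf


def toy_input_fn(is_training, data_dir, batch_size,
                 num_epochs=1, parse_record_fn=None):
  # Assumed TFRecord shard naming; the real ImageNet shards differ in detail.
  pattern = 'train-*' if is_training else 'validation-*'
  files = tf.data.Dataset.list_files(
      os.path.join(data_dir, pattern), shuffle=is_training)
  dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=4)
  if parse_record_fn is not None:
    # Assumed contract: parse_record_fn(raw_record, is_training, dtype).
    dataset = dataset.map(
        lambda raw: parse_record_fn(raw, is_training, tf.float32))
  return dataset.repeat(num_epochs).batch(batch_size)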
@@ -149,7 +150,7 @@ def run(flags_obj):
     num_eval_steps = None
     validation_data = None
 
-  history = model.fit(train_input_dataset,
+  model.fit(train_input_dataset,
             epochs=train_epochs,
             steps_per_epoch=train_steps,
             callbacks=[
@@ -162,7 +163,7 @@ def run(flags_obj):
             verbose=1)
 
   if not flags_obj.skip_eval:
-    eval_output = model.evaluate(eval_input_dataset,
+    model.evaluate(eval_input_dataset,
                   steps=num_eval_steps,
                   verbose=1)
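
Dropping the history = and eval_output = assignments above is what satisfies pylint's unused-variable check: the values were assigned but never read. The returns are still available when they are actually consumed, as in this hypothetical usage sketch (variable names and the printed fields are illustrative):

# Keep the assignments only when the values are read afterwards.
history = model.fit(train_input_dataset,
                    epochs=train_epochs,
                    steps_per_epoch=train_steps,
                    verbose=1)
print(history.history['loss'])  # per-epoch training loss recorded by Keras

eval_output = model.evaluate(eval_input_dataset,
                             steps=num_eval_steps,
                             verbose=1)
print(eval_output)  # [loss, metrics...] in model.compile order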