Commit 863a44ce authored by Toby Boyd

fixed lint issues.

parent 61ec2907
@@ -23,8 +23,6 @@ from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 
 from official.resnet import imagenet_main
-from official.resnet import imagenet_preprocessing
-from official.resnet import resnet_run_loop
 from official.resnet.keras import keras_common
 from official.resnet.keras import resnet_model
 from official.utils.flags import core as flags_core
@@ -37,28 +35,33 @@ LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
 ]
 
 
-def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
+def learning_rate_schedule(current_epoch,
+                           current_batch,
+                           batches_per_epoch,
+                           batch_size):
   """Handles linear scaling rule, gradual warmup, and LR decay.
 
-  Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the provided scaling
-  factor.
+  Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
+  provided scaling factor.
 
   Args:
     current_epoch: integer, current epoch indexed from 0.
     current_batch: integer, current batch in the current epoch, indexed from 0.
+    batches_per_epoch: integer, number of steps in an epoch.
+    batch_size: integer, total batch size.
 
   Returns:
     Adjusted learning rate.
   """
-  initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 256
+  initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256
   epoch = current_epoch + float(current_batch) / batches_per_epoch
   warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
   if epoch < warmup_end_epoch:
     # Learning rate increases linearly per step.
-    return initial_learning_rate * warmup_lr_multiplier * epoch / warmup_end_epoch
+    return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
   for mult, start_epoch in LR_SCHEDULE:
     if epoch >= start_epoch:
-      learning_rate = initial_learning_rate * mult
+      learning_rate = initial_lr * mult
     else:
       break
   return learning_rate
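For reference, a minimal standalone sketch of the warmup-and-decay logic in the hunk above. The BASE_LEARNING_RATE value and the LR_SCHEDULE entries below are illustrative assumptions (the real values are defined in keras_common and earlier in this file, and are not part of this diff); only the arithmetic mirrors the committed function.

BASE_LEARNING_RATE = 0.1  # assumed example value, not from this diff
LR_SCHEDULE = [  # (multiplier, epoch to start) tuples -- assumed example values
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]


def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch,
                           batch_size):
  """Same warmup + step-decay logic as the function in the diff above."""
  initial_lr = BASE_LEARNING_RATE * batch_size / 256
  epoch = current_epoch + float(current_batch) / batches_per_epoch
  warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
  if epoch < warmup_end_epoch:
    # Linear warmup: ramps from 0 toward initial_lr * multiplier over the
    # first warmup_end_epoch epochs, one small step per batch.
    return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
  # Initialized here for clarity; the committed code relies on the first
  # LR_SCHEDULE entry always matching once warmup is over.
  learning_rate = initial_lr
  for mult, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      learning_rate = initial_lr * mult
    else:
      break
  return learning_rate


if __name__ == '__main__':
  # With batch_size=1024 the linear scaling rule gives 0.1 * 1024 / 256 = 0.4.
  for epoch, batch in [(0, 0), (2, 156), (5, 0), (30, 0), (60, 0)]:
    lr = learning_rate_schedule(epoch, batch, batches_per_epoch=312,
                                batch_size=1024)
    print('epoch %2d batch %3d -> lr %.4f' % (epoch, batch, lr))

The printed values show the rate ramping linearly during warmup and dropping by the scheduled multiplier at each boundary epoch.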
@@ -71,7 +74,7 @@ def parse_record_keras(raw_record, is_training, dtype):
 
   # Subtract one so that labels are in [0, 1000), and cast to float32 for
   # Keras model.
   label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
-      dtype=tf.float32)
+                  dtype=tf.float32)
   return image, label
 
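The hunk above only re-indents the continuation line, but the label handling it touches is easy to check in isolation: the raw ImageNet label is 1-based, so it is shifted into [0, 1000) and cast to float32 for the Keras model. A tiny sketch, assuming TF 2.x eager execution (the committed code is graph-mode era, where the tensor would be evaluated in a session instead):

import tensorflow as tf

# Toy stand-in for the raw ImageNet label (1-based, int64 in the TFRecord).
label = tf.constant(42, dtype=tf.int64)

# Same transformation as in the hunk above: reshape to [1], shift to the
# 0-based range [0, 1000), and cast to float32 for the Keras model.
label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
                dtype=tf.float32)
print(label)  # eager mode: tf.Tensor([41.], shape=(1,), dtype=float32)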
@@ -106,23 +109,21 @@ def run(flags_obj):
   else:
     input_fn = imagenet_main.input_fn
 
-  train_input_dataset = input_fn(
-      is_training=True,
-      data_dir=flags_obj.data_dir,
-      batch_size=per_device_batch_size,
-      num_epochs=flags_obj.train_epochs,
-      parse_record_fn=parse_record_keras)
+  train_input_dataset = input_fn(is_training=True,
+                                 data_dir=flags_obj.data_dir,
+                                 batch_size=per_device_batch_size,
+                                 num_epochs=flags_obj.train_epochs,
+                                 parse_record_fn=parse_record_keras)
 
-  eval_input_dataset = input_fn(
-      is_training=False,
-      data_dir=flags_obj.data_dir,
-      batch_size=per_device_batch_size,
-      num_epochs=flags_obj.train_epochs,
-      parse_record_fn=parse_record_keras)
+  eval_input_dataset = input_fn(is_training=False,
+                                data_dir=flags_obj.data_dir,
+                                batch_size=per_device_batch_size,
+                                num_epochs=flags_obj.train_epochs,
+                                parse_record_fn=parse_record_keras)
 
   optimizer = keras_common.get_optimizer()
 
   strategy = distribution_utils.get_distribution_strategy(
       flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)
 
   model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)
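imagenet_main.input_fn itself is not part of this diff; to make the keyword interface used in the two calls above concrete, here is a hypothetical, much-simplified stand-in. The function name, file pattern, and pipeline details below are assumptions for illustration, not the repository's implementation.

import os
import tensorflow as tf


def toy_input_fn(is_training, data_dir, batch_size, num_epochs=1,
                 parse_record_fn=None):
  """Hypothetical, simplified stand-in for imagenet_main.input_fn.

  Only illustrates the keyword interface used in the diff above; the real
  function additionally handles sharded TFRecords, shuffle buffers, and
  dtype selection.
  """
  pattern = 'train-*' if is_training else 'validation-*'
  dataset = tf.data.Dataset.list_files(os.path.join(data_dir, pattern))
  dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)
  if is_training:
    dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.repeat(num_epochs)
  if parse_record_fn is not None:
    dataset = dataset.map(
        lambda record: parse_record_fn(record, is_training, tf.float32))
  dataset = dataset.batch(batch_size)
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)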
@@ -142,29 +143,29 @@ def run(flags_obj):
     train_epochs = 1
 
   num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                     flags_obj.batch_size)
 
   validation_data = eval_input_dataset
   if flags_obj.skip_eval:
     num_eval_steps = None
     validation_data = None
 
-  history = model.fit(train_input_dataset,
-                      epochs=train_epochs,
-                      steps_per_epoch=train_steps,
-                      callbacks=[
-                          time_callback,
-                          lr_callback,
-                          tensorboard_callback
-                      ],
-                      validation_steps=num_eval_steps,
-                      validation_data=validation_data,
-                      verbose=1)
+  model.fit(train_input_dataset,
+            epochs=train_epochs,
+            steps_per_epoch=train_steps,
+            callbacks=[
+                time_callback,
+                lr_callback,
+                tensorboard_callback
+            ],
+            validation_steps=num_eval_steps,
+            validation_data=validation_data,
+            verbose=1)
 
   if not flags_obj.skip_eval:
-    eval_output = model.evaluate(eval_input_dataset,
-                                 steps=num_eval_steps,
-                                 verbose=1)
+    model.evaluate(eval_input_dataset,
+                   steps=num_eval_steps,
+                   verbose=1)
def main(_):
...
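The lr_callback passed to model.fit above is created by keras_common and is not shown in this diff. As an illustration of how a per-batch schedule like learning_rate_schedule can drive the optimizer, here is a hedged sketch of a custom Keras callback; the class name and wiring are assumptions, not the actual keras_common implementation.

import tensorflow as tf


class BatchLRScheduler(tf.keras.callbacks.Callback):
  """Hypothetical per-batch LR scheduler; illustrative only.

  The repository's own lr_callback comes from keras_common and may differ.
  """

  def __init__(self, schedule, batches_per_epoch, batch_size):
    super(BatchLRScheduler, self).__init__()
    self.schedule = schedule
    self.batches_per_epoch = batches_per_epoch
    self.batch_size = batch_size
    self.epoch = 0

  def on_epoch_begin(self, epoch, logs=None):
    # Remember the current epoch so batch-level updates can compute a
    # fractional epoch, as learning_rate_schedule expects.
    self.epoch = epoch

  def on_batch_begin(self, batch, logs=None):
    lr = self.schedule(self.epoch, batch, self.batches_per_epoch,
                       self.batch_size)
    tf.keras.backend.set_value(self.model.optimizer.lr, lr)


# Usage sketch (assumed values):
# lr_callback = BatchLRScheduler(learning_rate_schedule,
#                                batches_per_epoch=train_steps,
#                                batch_size=flags_obj.batch_size)
# model.fit(..., callbacks=[lr_callback, ...])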