"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "0b8bfa849c2d6fef455de34fd8da1941ea08cbd9"
Commit 9bf586de authored by Toby Boyd, committed by Taylor Robie

Add 5 epoch warmup to resnet (#5176)

* Add 5-epoch warmup

* Apply warm_up in get_lr only for ImageNet

* Add base_lr; remove fp16 unittest arg validation

* Remove the validation check that blocked v1 with fp16
parent 981c0039
@@ -165,13 +165,6 @@ class BaseTest(tf.test.TestCase):
         extra_flags=['-resnet_version', '2']
     )
 
-  def test_flag_restriction(self):
-    with self.assertRaises(SystemExit):
-      integration.run_synthetic(
-          main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
-          extra_flags=['-resnet_version', '1', "-dtype", "fp16"]
-      )
-
 if __name__ == '__main__':
   tf.test.main()
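Aside: the deleted tests leaned on the fact that absl converts flag-validation failures into a process exit, which unittest can intercept as SystemExit. A minimal standalone sketch of that pattern (the flag and harness here are illustrative, not the repo's integration helper):

import unittest

from absl import app


class FlagRestrictionTest(unittest.TestCase):

  def test_bad_flag_exits(self):
    # absl prints usage and calls sys.exit(1) on any flags.Error,
    # so the test observes a SystemExit.
    with self.assertRaises(SystemExit):
      app.parse_flags_with_usage(['prog', '--no_such_flag=1'])


if __name__ == '__main__':
  unittest.main()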
@@ -285,10 +285,20 @@ def _get_block_sizes(resnet_size):
 
 def imagenet_model_fn(features, labels, mode, params):
   """Our model_fn for ResNet to be used with our Estimator."""
+
+  # Warmup and higher lr may not be valid for fine tuning with small batches
+  # and smaller numbers of training images.
+  if params['fine_tune']:
+    warmup = False
+    base_lr = .1
+  else:
+    warmup = True
+    base_lr = .128
+
   learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
       batch_size=params['batch_size'], batch_denom=256,
       num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
-      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])
+      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4], warmup=warmup, base_lr=base_lr)
 
   return resnet_run_loop.resnet_model_fn(
       features=features,
...
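The only numeric consequence of base_lr here is the linear batch-size scaling performed inside learning_rate_with_decay (initial lr = base_lr * batch_size / batch_denom). A quick arithmetic check, with batch sizes assumed purely for illustration:

# Illustrative check of the scaling rule; 256 is the batch_denom above.
for fine_tune, batch_size in [(False, 1024), (True, 64)]:
  base_lr = .1 if fine_tune else .128
  print(fine_tune, base_lr * batch_size / 256)
# False 0.512  (training from scratch; the 5-epoch warmup ramps up to this)
# True  0.025  (fine tuning; no warmup)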
@@ -304,13 +304,6 @@ class BaseTest(tf.test.TestCase):
         extra_flags=['-resnet_version', '2', '-resnet_size', '200']
     )
 
-  def test_flag_restriction(self):
-    with self.assertRaises(SystemExit):
-      integration.run_synthetic(
-          main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
-          extra_flags=['-resnet_version', '1', '-dtype', 'fp16']
-      )
-
 if __name__ == '__main__':
   tf.test.main()
@@ -138,7 +138,8 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
 # Functions for running training/eval/validation loops for the model.
 ################################################################################
 def learning_rate_with_decay(
-    batch_size, batch_denom, num_images, boundary_epochs, decay_rates):
+    batch_size, batch_denom, num_images, boundary_epochs, decay_rates,
+    base_lr=0.1, warmup=False):
   """Get a learning rate that decays step-wise as training progresses.
 
   Args:
@@ -152,13 +153,14 @@ def learning_rate_with_decay(
     decay_rates: list of floats representing the decay rates to be used
       for scaling the learning rate. It should have one more element
       than `boundary_epochs`, and all elements should have the same type.
+    base_lr: Initial learning rate scaled based on batch_denom.
+    warmup: Run a 5 epoch warmup to the initial lr.
 
   Returns:
     Returns a function that takes a single argument - the number of batches
     trained so far (global_step) - and returns the learning rate to be used
     for training the next batch.
   """
-  initial_learning_rate = 0.1 * batch_size / batch_denom
+  initial_learning_rate = base_lr * batch_size / batch_denom
   batches_per_epoch = num_images / batch_size
 
   # Reduce the learning rate at certain epochs.
@@ -168,8 +170,15 @@ def learning_rate_with_decay(
   vals = [initial_learning_rate * decay for decay in decay_rates]
 
   def learning_rate_fn(global_step):
-    global_step = tf.cast(global_step, tf.int32)
-    return tf.train.piecewise_constant(global_step, boundaries, vals)
+    """Builds scaled learning rate function with 5 epoch warm up."""
+    lr = tf.train.piecewise_constant(global_step, boundaries, vals)
+    if warmup:
+      warmup_steps = int(batches_per_epoch * 5)
+      warmup_lr = (
+          initial_learning_rate * tf.cast(global_step, tf.float32) / tf.cast(
+              warmup_steps, tf.float32))
+      return tf.cond(global_step < warmup_steps, lambda: warmup_lr, lambda: lr)
+    return lr
 
   return learning_rate_fn
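To sanity-check the resulting schedule without building a TF graph, here is a dependency-free Python mirror of learning_rate_fn; all sizes below are assumptions for illustration (ImageNet-like image count, batch size 256), not values taken from the repo:

def make_lr_fn(batch_size=256, batch_denom=256, num_images=1281167,
               boundary_epochs=(30, 60, 80, 90),
               decay_rates=(1, 0.1, 0.01, 0.001, 1e-4),
               base_lr=0.128, warmup=True):
  """Pure-Python mirror of the TF schedule above, for eyeballing values."""
  initial_lr = base_lr * batch_size / batch_denom
  batches_per_epoch = num_images / batch_size
  boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
  vals = [initial_lr * decay for decay in decay_rates]
  warmup_steps = int(batches_per_epoch * 5)

  def lr_fn(global_step):
    if warmup and global_step < warmup_steps:
      # Linear ramp from 0 up to initial_lr over the first 5 epochs.
      return initial_lr * global_step / warmup_steps
    for boundary, val in zip(boundaries, vals):
      if global_step < boundary:
        return val
    return vals[-1]

  return lr_fn

lr = make_lr_fn()
print(lr(0), lr(12500), lr(50000), lr(200000))
# 0.0, ~0.064 (mid-warmup), 0.128, 0.0128

Because decay_rates[0] is 1, the post-warmup value equals the warmup target, so the linear ramp meets the piecewise schedule continuously at step warmup_steps.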
@@ -499,12 +508,3 @@ def define_resnet_flags(resnet_size_choices=None):
     flags.DEFINE_string(**choice_kwargs)
   else:
     flags.DEFINE_enum(enum_values=resnet_size_choices, **choice_kwargs)
-
-  # The current implementation of ResNet v1 is numerically unstable when run
-  # with fp16 and will produce NaN errors soon after training begins.
-  msg = ('ResNet version 1 is not currently supported with fp16. '
-         'Please use version 2 instead.')
-  @flags.multi_flags_validator(['dtype', 'resnet_version'], message=msg)
-  def _forbid_v1_fp16(flag_values):  # pylint: disable=unused-variable
-    return (flags_core.DTYPE_MAP[flag_values['dtype']][0] != tf.float16 or
-            flag_values['resnet_version'] != '1')
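For readers unfamiliar with the mechanism being deleted: absl runs multi-flag validators at parse time, and a False return raises an error carrying the message. A simplified standalone sketch (plain enum flags stand in for the repo's flags_core definitions):

from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_enum('dtype', 'fp32', ['fp16', 'fp32'], 'Compute dtype.')
flags.DEFINE_enum('resnet_version', '2', ['1', '2'], 'ResNet version.')


@flags.multi_flags_validator(
    ['dtype', 'resnet_version'],
    message='ResNet version 1 is not currently supported with fp16.')
def _forbid_v1_fp16(flag_values):  # pylint: disable=unused-variable
  return not (flag_values['dtype'] == 'fp16' and
              flag_values['resnet_version'] == '1')


# FLAGS(['prog', '--dtype=fp16', '--resnet_version=1']) would raise
# flags.IllegalFlagValueError; under absl.app this becomes sys.exit(1).

Removing this validator, together with the two deleted tests above, is what re-enables the v1/fp16 combination.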