Unverified Commit 17e923da authored by Reed, committed by GitHub

Add dynamic loss scaling support (#6518)

parent cc9eef76
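Editorial note (not part of the commit): the new "dynamic" value for --loss_scale turns on dynamic loss scaling, where the scale is not a fixed constant but is lowered when scaled gradients overflow and raised again after a run of finite steps. The sketch below only illustrates that general idea; every name in it is hypothetical and nothing here is taken from this commit.

# Illustrative sketch of dynamic loss scaling (not from this commit).
# Real trainers use their framework's loss-scaling utilities instead.
import math


class DynamicLossScaler(object):
  """Adjusts the loss scale based on whether scaled gradients overflowed."""

  def __init__(self, initial_scale=2.0 ** 15, growth_factor=2.0,
               backoff_factor=0.5, growth_interval=2000):
    self.scale = initial_scale
    self.growth_factor = growth_factor
    self.backoff_factor = backoff_factor
    self.growth_interval = growth_interval
    self._good_steps = 0

  def update(self, grads_are_finite):
    """Call once per step after inspecting the unscaled gradients."""
    if grads_are_finite:
      self._good_steps += 1
      if self._good_steps >= self.growth_interval:
        self.scale *= self.growth_factor   # try a larger scale
        self._good_steps = 0
    else:
      self.scale = max(1.0, self.scale * self.backoff_factor)  # back off
      self._good_steps = 0
    return grads_are_finite  # the step is skipped when gradients overflow


# Toy usage: the loss is multiplied by scaler.scale before backprop and the
# resulting gradients are divided by the same scale before being applied.
scaler = DynamicLossScaler()
for g in [0.1, float("inf"), 0.2]:
  scaler.update(grads_are_finite=math.isfinite(g))
  print(scaler.scale)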
@@ -343,9 +343,10 @@ def imagenet_model_fn(features, labels, mode, params):
   )


-def define_imagenet_flags():
+def define_imagenet_flags(dynamic_loss_scale=False):
   resnet_run_loop.define_resnet_flags(
-      resnet_size_choices=['18', '34', '50', '101', '152', '200'])
+      resnet_size_choices=['18', '34', '50', '101', '152', '200'],
+      dynamic_loss_scale=dynamic_loss_scale)
   flags.adopt_module_key_flags(resnet_run_loop)
   flags_core.set_defaults(train_epochs=90)
@@ -235,6 +235,6 @@ def main(_):

 if __name__ == '__main__':
   tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
-  imagenet_main.define_imagenet_flags()
+  imagenet_main.define_imagenet_flags(dynamic_loss_scale=True)
   keras_common.define_keras_flags()
   absl_app.run(main)
@@ -707,13 +707,14 @@ def resnet_main(
   return stats


-def define_resnet_flags(resnet_size_choices=None):
+def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False):
   """Add flags and validators for ResNet."""
   flags_core.define_base()
   flags_core.define_performance(num_parallel_calls=False,
                                 tf_gpu_thread_mode=True,
                                 datasets_num_private_threads=True,
-                                datasets_num_parallel_batches=True)
+                                datasets_num_parallel_batches=True,
+                                dynamic_loss_scale=dynamic_loss_scale)
   flags_core.define_image()
   flags_core.define_benchmark()
   flags.adopt_module_key_flags(flags_core)
@@ -38,8 +38,10 @@ def get_tf_dtype(flags_obj):

 def get_loss_scale(flags_obj):
-  if flags_obj.loss_scale is not None:
+  if flags_obj.loss_scale == "dynamic":
     return flags_obj.loss_scale
+  elif flags_obj.loss_scale is not None:
+    return float(flags_obj.loss_scale)
   return DTYPE_MAP[flags_obj.dtype][1]
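Editorial note: the hunk above makes get_loss_scale handle the loss_scale flag as a string, passing the sentinel "dynamic" through, converting any other supplied value to a float, or falling back to the per-dtype default. A standalone sketch of that behaviour follows; the FakeFlags object and the DTYPE_MAP values are stand-ins invented for the example, not the module's real definitions.

# Standalone sketch of the updated get_loss_scale behaviour (illustrative
# only; the flag object and DTYPE_MAP below are stand-ins for the real ones).
import collections

FakeFlags = collections.namedtuple("FakeFlags", ["loss_scale", "dtype"])
DTYPE_MAP = {"fp16": (None, 128), "fp32": (None, 1)}  # (dtype, default scale)


def get_loss_scale(flags_obj):
  if flags_obj.loss_scale == "dynamic":
    return flags_obj.loss_scale          # pass the sentinel string through
  elif flags_obj.loss_scale is not None:
    return float(flags_obj.loss_scale)   # string flag -> numeric scale
  return DTYPE_MAP[flags_obj.dtype][1]   # fall back to the dtype default


print(get_loss_scale(FakeFlags(loss_scale="dynamic", dtype="fp16")))  # "dynamic"
print(get_loss_scale(FakeFlags(loss_scale="5", dtype="fp16")))        # 5.0
print(get_loss_scale(FakeFlags(loss_scale=None, dtype="fp16")))       # 128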
@@ -47,7 +49,8 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
                        synthetic_data=True, max_train_steps=True, dtype=True,
                        all_reduce_alg=True, tf_gpu_thread_mode=False,
                        datasets_num_private_threads=False,
-                       datasets_num_parallel_batches=False):
+                       datasets_num_parallel_batches=False,
+                       dynamic_loss_scale=False):
   """Register flags for specifying performance tuning arguments.

   Args:
@@ -63,6 +66,8 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
     datasets_num_private_threads: Number of private threads for datasets.
     datasets_num_parallel_batches: Determines how many batches to process in
       parallel when using map and batch from tf.data.
+    dynamic_loss_scale: Allow the "loss_scale" flag to take on the value
+      "dynamic". Only valid if `dtype` is True.

   Returns:
     A list of flags for core.py to marks as key flags.
@@ -117,24 +122,46 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
             "Variables may be cast to a higher precision on a "
             "case-by-case basis for numerical stability."))

-    flags.DEFINE_integer(
-        name="loss_scale", short_name="ls", default=None,
-        help=help_wrap(
-            "The amount to scale the loss by when the model is run. Before "
+    loss_scale_help_text = (
+        "The amount to scale the loss by when the model is run. {}. Before "
         "gradients are computed, the loss is multiplied by the loss scale, "
         "making all gradients loss_scale times larger. To adjust for this, "
         "gradients are divided by the loss scale before being applied to "
         "variables. This is mathematically equivalent to training without "
         "a loss scale, but the loss scale helps avoid some intermediate "
         "gradients from underflowing to zero. If not provided the default "
-        "for fp16 is 128 and 1 for all other dtypes."))
+        "for fp16 is 128 and 1 for all other dtypes.{}"
+    )
+    if dynamic_loss_scale:
+      loss_scale_help_text = loss_scale_help_text.format(
+          "This can be an int/float or the string 'dynamic'",
+          " The string 'dynamic' can be used to dynamically determine the "
+          "optimal loss scale during training, but currently this "
+          "significantly slows down performance")
+      loss_scale_validation_msg = ("loss_scale should be a positive int/float "
+                                   "or the string 'dynamic'.")
+    else:
+      loss_scale_help_text = loss_scale_help_text.format(
+          "This must be an int/float", "")
+      loss_scale_validation_msg = "loss_scale should be a positive int/float."
+
+    flags.DEFINE_string(
+        name="loss_scale", short_name="ls", default=None,
+        help=help_wrap(loss_scale_help_text))

-    loss_scale_val_msg = "loss_scale should be a positive integer."
-    @flags.validator(flag_name="loss_scale", message=loss_scale_val_msg)
+    @flags.validator(flag_name="loss_scale", message=loss_scale_validation_msg)
     def _check_loss_scale(loss_scale):  # pylint: disable=unused-variable
+      """Validator to check the loss scale flag is valid"""
+      if loss_scale is None:
+        return True  # null case is handled in get_loss_scale()
+
+      if loss_scale == "dynamic" and dynamic_loss_scale:
+        return True
+
+      try:
+        loss_scale = float(loss_scale)
+      except ValueError:
+        return False
+
       return loss_scale > 0

   if all_reduce_alg:
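Editorial note: in short, the validator above accepts None (resolved later by get_loss_scale), the string "dynamic" when dynamic_loss_scale is enabled, and any positive number; everything else is rejected. The small standalone restatement below is illustrative only and assumes dynamic_loss_scale is True.

# Illustrative restatement of the validator logic, outside of absl, so the
# accepted values are easy to see (dynamic_loss_scale assumed True here).
def check_loss_scale(loss_scale, dynamic_loss_scale=True):
  if loss_scale is None:
    return True                    # default; resolved later by get_loss_scale()
  if loss_scale == "dynamic" and dynamic_loss_scale:
    return True                    # sentinel string enables dynamic scaling
  try:
    return float(loss_scale) > 0   # any positive number is accepted
  except ValueError:
    return False                   # non-numeric strings are rejected


for value in [None, "dynamic", "128", "0.5", "0", "-2", "abc"]:
  print(value, check_loss_scale(value))
# Expected: None, "dynamic", "128" and "0.5" pass; "0", "-2" and "abc" fail.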
@@ -23,7 +23,7 @@ from official.utils.flags import core as flags_core  # pylint: disable=g-bad-imp

 def define_flags():
   flags_core.define_base(num_gpu=False)
-  flags_core.define_performance()
+  flags_core.define_performance(dynamic_loss_scale=True)
   flags_core.define_image()
   flags_core.define_benchmark()
@@ -89,12 +89,20 @@ class BaseTester(unittest.TestCase):
       flags_core.parse_flags(
           [__file__, "--dtype", dtype_str, "--loss_scale", "5"])
       self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), 5)

+      flags_core.parse_flags(
+          [__file__, "--dtype", dtype_str, "--loss_scale", "dynamic"])
+      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), "dynamic")
+
     with self.assertRaises(SystemExit):
       flags_core.parse_flags([__file__, "--dtype", "int8"])

+    with self.assertRaises(SystemExit):
+      flags_core.parse_flags([__file__, "--dtype", "fp16",
+                              "--loss_scale", "abc"])
+

 if __name__ == "__main__":
   unittest.main()
@@ -47,6 +47,8 @@

 def get_loss_scale(flags_obj):
+  if flags_obj.loss_scale == "dynamic":
+    return flags_obj.loss_scale
   if flags_obj.loss_scale is not None:
     return flags_obj.loss_scale
   return DTYPE_MAP[flags_obj.dtype][1]