Unverified Commit 17e923da authored by Reed's avatar Reed Committed by GitHub
Browse files

Add dynamic loss scaling support (#6518)

parent cc9eef76
...@@ -343,9 +343,10 @@ def imagenet_model_fn(features, labels, mode, params): ...@@ -343,9 +343,10 @@ def imagenet_model_fn(features, labels, mode, params):
) )
def define_imagenet_flags(): def define_imagenet_flags(dynamic_loss_scale=False):
resnet_run_loop.define_resnet_flags( resnet_run_loop.define_resnet_flags(
resnet_size_choices=['18', '34', '50', '101', '152', '200']) resnet_size_choices=['18', '34', '50', '101', '152', '200'],
dynamic_loss_scale=dynamic_loss_scale)
flags.adopt_module_key_flags(resnet_run_loop) flags.adopt_module_key_flags(resnet_run_loop)
flags_core.set_defaults(train_epochs=90) flags_core.set_defaults(train_epochs=90)
......
...@@ -235,6 +235,6 @@ def main(_): ...@@ -235,6 +235,6 @@ def main(_):
if __name__ == '__main__': if __name__ == '__main__':
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
imagenet_main.define_imagenet_flags() imagenet_main.define_imagenet_flags(dynamic_loss_scale=True)
keras_common.define_keras_flags() keras_common.define_keras_flags()
absl_app.run(main) absl_app.run(main)
...@@ -707,13 +707,14 @@ def resnet_main( ...@@ -707,13 +707,14 @@ def resnet_main(
return stats return stats
def define_resnet_flags(resnet_size_choices=None): def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False):
"""Add flags and validators for ResNet.""" """Add flags and validators for ResNet."""
flags_core.define_base() flags_core.define_base()
flags_core.define_performance(num_parallel_calls=False, flags_core.define_performance(num_parallel_calls=False,
tf_gpu_thread_mode=True, tf_gpu_thread_mode=True,
datasets_num_private_threads=True, datasets_num_private_threads=True,
datasets_num_parallel_batches=True) datasets_num_parallel_batches=True,
dynamic_loss_scale=dynamic_loss_scale)
flags_core.define_image() flags_core.define_image()
flags_core.define_benchmark() flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core) flags.adopt_module_key_flags(flags_core)
......
...@@ -38,8 +38,10 @@ def get_tf_dtype(flags_obj): ...@@ -38,8 +38,10 @@ def get_tf_dtype(flags_obj):
def get_loss_scale(flags_obj): def get_loss_scale(flags_obj):
if flags_obj.loss_scale is not None: if flags_obj.loss_scale == "dynamic":
return flags_obj.loss_scale return flags_obj.loss_scale
elif flags_obj.loss_scale is not None:
return float(flags_obj.loss_scale)
return DTYPE_MAP[flags_obj.dtype][1] return DTYPE_MAP[flags_obj.dtype][1]
...@@ -47,7 +49,8 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True, ...@@ -47,7 +49,8 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
synthetic_data=True, max_train_steps=True, dtype=True, synthetic_data=True, max_train_steps=True, dtype=True,
all_reduce_alg=True, tf_gpu_thread_mode=False, all_reduce_alg=True, tf_gpu_thread_mode=False,
datasets_num_private_threads=False, datasets_num_private_threads=False,
datasets_num_parallel_batches=False): datasets_num_parallel_batches=False,
dynamic_loss_scale=False):
"""Register flags for specifying performance tuning arguments. """Register flags for specifying performance tuning arguments.
Args: Args:
...@@ -63,6 +66,8 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True, ...@@ -63,6 +66,8 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
datasets_num_private_threads: Number of private threads for datasets. datasets_num_private_threads: Number of private threads for datasets.
datasets_num_parallel_batches: Determines how many batches to process in datasets_num_parallel_batches: Determines how many batches to process in
parallel when using map and batch from tf.data. parallel when using map and batch from tf.data.
dynamic_loss_scale: Allow the "loss_scale" flag to take on the value
"dynamic". Only valid if `dtype` is True.
Returns: Returns:
A list of flags for core.py to marks as key flags. A list of flags for core.py to marks as key flags.
...@@ -117,24 +122,46 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True, ...@@ -117,24 +122,46 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
"Variables may be cast to a higher precision on a " "Variables may be cast to a higher precision on a "
"case-by-case basis for numerical stability.")) "case-by-case basis for numerical stability."))
flags.DEFINE_integer( loss_scale_help_text = (
name="loss_scale", short_name="ls", default=None, "The amount to scale the loss by when the model is run. {}. Before "
help=help_wrap(
"The amount to scale the loss by when the model is run. Before "
"gradients are computed, the loss is multiplied by the loss scale, " "gradients are computed, the loss is multiplied by the loss scale, "
"making all gradients loss_scale times larger. To adjust for this, " "making all gradients loss_scale times larger. To adjust for this, "
"gradients are divided by the loss scale before being applied to " "gradients are divided by the loss scale before being applied to "
"variables. This is mathematically equivalent to training without " "variables. This is mathematically equivalent to training without "
"a loss scale, but the loss scale helps avoid some intermediate " "a loss scale, but the loss scale helps avoid some intermediate "
"gradients from underflowing to zero. If not provided the default " "gradients from underflowing to zero. If not provided the default "
"for fp16 is 128 and 1 for all other dtypes.")) "for fp16 is 128 and 1 for all other dtypes.{}"
)
if dynamic_loss_scale:
loss_scale_help_text = loss_scale_help_text.format(
"This can be an int/float or the string 'dynamic'",
" The string 'dynamic' can be used to dynamically determine the "
"optimal loss scale during training, but currently this "
"significantly slows down performance")
loss_scale_validation_msg = ("loss_scale should be a positive int/float "
"or the string 'dynamic'.")
else:
loss_scale_help_text = loss_scale_help_text.format(
"This must be an int/float", "")
loss_scale_validation_msg = "loss_scale should be a positive int/float."
flags.DEFINE_string(
name="loss_scale", short_name="ls", default=None,
help=help_wrap(loss_scale_help_text))
loss_scale_val_msg = "loss_scale should be a positive integer." @flags.validator(flag_name="loss_scale", message=loss_scale_validation_msg)
@flags.validator(flag_name="loss_scale", message=loss_scale_val_msg)
def _check_loss_scale(loss_scale): # pylint: disable=unused-variable def _check_loss_scale(loss_scale): # pylint: disable=unused-variable
"""Validator to check the loss scale flag is valid"""
if loss_scale is None: if loss_scale is None:
return True # null case is handled in get_loss_scale() return True # null case is handled in get_loss_scale()
if loss_scale == "dynamic" and dynamic_loss_scale:
return True
try:
loss_scale = float(loss_scale)
except ValueError:
return False
return loss_scale > 0 return loss_scale > 0
if all_reduce_alg: if all_reduce_alg:
......
...@@ -23,7 +23,7 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp ...@@ -23,7 +23,7 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp
def define_flags(): def define_flags():
flags_core.define_base(num_gpu=False) flags_core.define_base(num_gpu=False)
flags_core.define_performance() flags_core.define_performance(dynamic_loss_scale=True)
flags_core.define_image() flags_core.define_image()
flags_core.define_benchmark() flags_core.define_benchmark()
...@@ -89,12 +89,20 @@ class BaseTester(unittest.TestCase): ...@@ -89,12 +89,20 @@ class BaseTester(unittest.TestCase):
flags_core.parse_flags( flags_core.parse_flags(
[__file__, "--dtype", dtype_str, "--loss_scale", "5"]) [__file__, "--dtype", dtype_str, "--loss_scale", "5"])
self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), 5) self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), 5)
flags_core.parse_flags(
[__file__, "--dtype", dtype_str, "--loss_scale", "dynamic"])
self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), "dynamic")
with self.assertRaises(SystemExit): with self.assertRaises(SystemExit):
flags_core.parse_flags([__file__, "--dtype", "int8"]) flags_core.parse_flags([__file__, "--dtype", "int8"])
with self.assertRaises(SystemExit):
flags_core.parse_flags([__file__, "--dtype", "fp16",
"--loss_scale", "abc"])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -47,6 +47,8 @@ ...@@ -47,6 +47,8 @@
def get_loss_scale(flags_obj): def get_loss_scale(flags_obj):
if flags_obj.loss_scale == "dynamic":
return flags_obj.loss_scale
if flags_obj.loss_scale is not None: if flags_obj.loss_scale is not None:
return flags_obj.loss_scale return flags_obj.loss_scale
return DTYPE_MAP[flags_obj.dtype][1] return DTYPE_MAP[flags_obj.dtype][1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment