Unverified commit b691578c authored by Reed, committed by GitHub

Add --fp16_implementation option. (#6703)

This option allows the new tf.train.experimental.enable_mixed_precision_graph_rewrite() function to be used for fp16, instead of manual casts.
parent 62ed862e
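For context, a minimal sketch of how the graph-rewrite path is typically used in TF 1.14+; this snippet is illustrative and not part of the commit, and the optimizer choice and hyperparameters here are assumptions:

import tensorflow as tf

# Build the model in fp32; the graph rewrite converts eligible ops to fp16 at runtime.
optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
# Wrapping the optimizer adds automatic loss scaling to avoid fp16 gradient underflow.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    optimizer, loss_scale='dynamic')
# train_op = optimizer.minimize(loss)  # gradients are loss-scaled automatically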
@@ -347,10 +347,11 @@ def imagenet_model_fn(features, labels, mode, params):
  )

-def define_imagenet_flags(dynamic_loss_scale=False):
+def define_imagenet_flags(dynamic_loss_scale=False, fp16_implementation=False):
  resnet_run_loop.define_resnet_flags(
      resnet_size_choices=['18', '34', '50', '101', '152', '200'],
-      dynamic_loss_scale=dynamic_loss_scale)
+      dynamic_loss_scale=dynamic_loss_scale,
+      fp16_implementation=fp16_implementation)
  flags.adopt_module_key_flags(resnet_run_loop)
  flags_core.set_defaults(train_epochs=90)
@@ -385,5 +386,5 @@ def main(_):

if __name__ == '__main__':
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
-  define_imagenet_flags()
+  define_imagenet_flags(dynamic_loss_scale=True, fp16_implementation=True)
  absl_app.run(main)
@@ -451,6 +451,11 @@ def resnet_model_fn(features, labels, mode, model_class,
        momentum=momentum
    )

+    fp16_implementation = getattr(flags.FLAGS, 'fp16_implementation', None)
+    if fp16_implementation == 'graph_rewrite':
+      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
+          optimizer, loss_scale=loss_scale)
+
    def _dense_grad_filter(gvs):
      """Only apply gradient updates to the final layer.
@@ -463,7 +468,7 @@ def resnet_model_fn(features, labels, mode, model_class,
      """
      return [(g, v) for g, v in gvs if 'dense' in v.name]

-    if loss_scale != 1:
+    if loss_scale != 1 and fp16_implementation != 'graph_rewrite':
      # When computing fp16 gradients, often intermediate tensor values are
      # so small, they underflow to 0. To avoid this, we multiply the loss by
      # loss_scale to make these tensor values loss_scale times bigger.
@@ -708,14 +713,16 @@ def resnet_main(
  return stats

-def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False):
+def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False,
+                        fp16_implementation=False):
  """Add flags and validators for ResNet."""
  flags_core.define_base()
  flags_core.define_performance(num_parallel_calls=False,
                                tf_gpu_thread_mode=True,
                                datasets_num_private_threads=True,
                                datasets_num_parallel_batches=True,
-                                dynamic_loss_scale=dynamic_loss_scale)
+                                dynamic_loss_scale=dynamic_loss_scale,
+                                fp16_implementation=fp16_implementation)
  flags_core.define_image()
  flags_core.define_benchmark()
  flags.adopt_module_key_flags(flags_core)
......
@@ -34,6 +34,10 @@ DTYPE_MAP = {

def get_tf_dtype(flags_obj):
+  if getattr(flags_obj, 'fp16_implementation', None) == 'graph_rewrite':
+    # If the graph_rewrite is used, we build the graph with fp32, and let the
+    # graph rewrite change ops to fp16.
+    return tf.float32
  return DTYPE_MAP[flags_obj.dtype][0]
@@ -51,7 +55,7 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
                       tf_gpu_thread_mode=False,
                       datasets_num_private_threads=False,
                       datasets_num_parallel_batches=False,
-                       dynamic_loss_scale=False):
+                       dynamic_loss_scale=False, fp16_implementation=False):
  """Register flags for specifying performance tuning arguments.
Args: Args:
...@@ -71,6 +75,7 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True, ...@@ -71,6 +75,7 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
parallel when using map and batch from tf.data. parallel when using map and batch from tf.data.
dynamic_loss_scale: Allow the "loss_scale" flag to take on the value dynamic_loss_scale: Allow the "loss_scale" flag to take on the value
"dynamic". Only valid if `dtype` is True. "dynamic". Only valid if `dtype` is True.
fp16_implementation: Create fp16_implementation flag.
Returns: Returns:
A list of flags for core.py to marks as key flags. A list of flags for core.py to marks as key flags.
@@ -167,6 +172,33 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
      return loss_scale > 0

+  if fp16_implementation:
+    # Currently, this flag is only defined for the estimator resnet model.
+    flags.DEFINE_enum(
+        name="fp16_implementation", default='casting',
+        enum_values=('casting', 'graph_rewrite'),
+        help=help_wrap(
+            "When --dtype=fp16, how fp16 should be implemented. This has no "
+            "impact on correctness. 'casting' will cause manual tf.casts to "
+            "be inserted in the model. 'graph_rewrite' means "
+            "tf.train.experimental.enable_mixed_precision_graph_rewrite will "
+            "be used to automatically use fp16 without any manual casts."))
+
+    @flags.multi_flags_validator(['fp16_implementation', 'dtype',
+                                  'loss_scale'])
+    def _check_fp16_implementation(flags_dict):
+      """Validator to check fp16_implementation flag is valid."""
+      if (flags_dict['fp16_implementation'] == 'graph_rewrite' and
+          flags_dict['dtype'] != 'fp16'):
+        raise flags.ValidationError('--fp16_implementation should not be '
+                                    'specified unless --dtype=fp16')
+      if (flags_dict['fp16_implementation'] != 'graph_rewrite' and
+          flags_dict['loss_scale'] == 'dynamic'):
+        raise flags.ValidationError('--loss_scale=dynamic is only supported '
+                                    'when '
+                                    '--fp16_implementation=graph_rewrite')
+      return True
+
  if all_reduce_alg:
    flags.DEFINE_string(
        name="all_reduce_alg", short_name="ara", default=None,
......