Unverified commit b691578c authored by Reed, committed by GitHub

Add --fp16_implementation option. (#6703)

This option allows the new tf.train.experimental.enable_mixed_precision_graph_rewrite() function to be used for fp16 instead of manual casts.
parent 62ed862e
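For context, a minimal sketch of the two fp16 paths the commit message refers to; this is illustrative only, not code from this commit, and the optimizer below is a placeholder:

import tensorflow as tf

# 'casting' path: the model itself inserts manual tf.cast ops so compute runs
# in fp16, and loss scaling is wired up by hand around a regular optimizer.
optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)

# 'graph_rewrite' path: the graph is built entirely in fp32, and wrapping the
# optimizer enables an automatic rewrite that converts eligible ops to fp16
# and handles loss scaling, with no manual casts in the model code.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    optimizer, loss_scale='dynamic')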
@@ -347,10 +347,11 @@ def imagenet_model_fn(features, labels, mode, params):
   )
-def define_imagenet_flags(dynamic_loss_scale=False):
+def define_imagenet_flags(dynamic_loss_scale=False, fp16_implementation=False):
   resnet_run_loop.define_resnet_flags(
       resnet_size_choices=['18', '34', '50', '101', '152', '200'],
-      dynamic_loss_scale=dynamic_loss_scale)
+      dynamic_loss_scale=dynamic_loss_scale,
+      fp16_implementation=fp16_implementation)
   flags.adopt_module_key_flags(resnet_run_loop)
   flags_core.set_defaults(train_epochs=90)
@@ -385,5 +386,5 @@ def main(_):
 if __name__ == '__main__':
   tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
-  define_imagenet_flags()
+  define_imagenet_flags(dynamic_loss_scale=True, fp16_implementation=True)
   absl_app.run(main)
@@ -451,6 +451,11 @@ def resnet_model_fn(features, labels, mode, model_class,
         momentum=momentum
     )
+    fp16_implementation = getattr(flags.FLAGS, 'fp16_implementation', None)
+    if fp16_implementation == 'graph_rewrite':
+      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
+          optimizer, loss_scale=loss_scale)
     def _dense_grad_filter(gvs):
       """Only apply gradient updates to the final layer.
@@ -463,7 +468,7 @@ def resnet_model_fn(features, labels, mode, model_class,
       """
       return [(g, v) for g, v in gvs if 'dense' in v.name]
-    if loss_scale != 1:
+    if loss_scale != 1 and fp16_implementation != 'graph_rewrite':
       # When computing fp16 gradients, often intermediate tensor values are
       # so small, they underflow to 0. To avoid this, we multiply the loss by
       # loss_scale to make these tensor values loss_scale times bigger.
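For readers unfamiliar with the manual loss scaling the comment above describes, a hypothetical helper sketching that technique (the names here are placeholders, not the surrounding code):

def minimize_with_loss_scale(optimizer, loss, loss_scale, global_step):
  # Multiply the loss so small fp16 gradient values do not underflow to 0.
  scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)
  # Divide the resulting gradients by the same factor before applying them,
  # so the update is mathematically unchanged.
  unscaled_grad_vars = [(grad / loss_scale, var)
                        for grad, var in scaled_grad_vars]
  return optimizer.apply_gradients(unscaled_grad_vars, global_step)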
@@ -708,14 +713,16 @@ def resnet_main(
   return stats
-def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False):
+def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False,
+                        fp16_implementation=False):
   """Add flags and validators for ResNet."""
   flags_core.define_base()
   flags_core.define_performance(num_parallel_calls=False,
                                 tf_gpu_thread_mode=True,
                                 datasets_num_private_threads=True,
                                 datasets_num_parallel_batches=True,
-                                dynamic_loss_scale=dynamic_loss_scale)
+                                dynamic_loss_scale=dynamic_loss_scale,
+                                fp16_implementation=fp16_implementation)
   flags_core.define_image()
   flags_core.define_benchmark()
   flags.adopt_module_key_flags(flags_core)
......
@@ -34,6 +34,10 @@ DTYPE_MAP = {
 def get_tf_dtype(flags_obj):
+  if getattr(flags_obj, 'fp16_implementation', None) == 'graph_rewrite':
+    # If the graph_rewrite is used, we build the graph with fp32, and let the
+    # graph rewrite change ops to fp16.
+    return tf.float32
   return DTYPE_MAP[flags_obj.dtype][0]
@@ -51,7 +55,7 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
                        tf_gpu_thread_mode=False,
                        datasets_num_private_threads=False,
                        datasets_num_parallel_batches=False,
-                       dynamic_loss_scale=False):
+                       dynamic_loss_scale=False, fp16_implementation=False):
   """Register flags for specifying performance tuning arguments.
   Args:
@@ -71,6 +75,7 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
       parallel when using map and batch from tf.data.
     dynamic_loss_scale: Allow the "loss_scale" flag to take on the value
       "dynamic". Only valid if `dtype` is True.
+    fp16_implementation: Create fp16_implementation flag.
   Returns:
     A list of flags for core.py to marks as key flags.
@@ -167,6 +172,33 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
       return loss_scale > 0
+  if fp16_implementation:
+    # Currently, this flag is only defined for the estimator resnet model.
+    flags.DEFINE_enum(
+        name="fp16_implementation", default='casting',
+        enum_values=('casting', 'graph_rewrite'),
+        help=help_wrap(
+            "When --dtype=fp16, how fp16 should be implemented. This has no "
+            "impact on correctness. 'casting' will cause manual tf.casts to "
+            "be inserted in the model. 'graph_rewrite' means "
+            "tf.train.experimental.enable_mixed_precision_graph_rewrite will "
+            "be used to automatically use fp16 without any manual casts."))
+    @flags.multi_flags_validator(['fp16_implementation', 'dtype',
+                                  'loss_scale'])
+    def _check_fp16_implementation(flags_dict):
+      """Validator to check fp16_implementation flag is valid."""
+      if (flags_dict['fp16_implementation'] == 'graph_rewrite' and
+          flags_dict['dtype'] != 'fp16'):
+        raise flags.ValidationError('--fp16_implementation should not be '
+                                    'specified unless --dtype=fp16')
+      if (flags_dict['fp16_implementation'] != 'graph_rewrite' and
+          flags_dict['loss_scale'] == 'dynamic'):
+        raise flags.ValidationError('--loss_scale=dynamic is only supported '
+                                    'when '
+                                    '--fp16_implementation=graph_rewrite')
+      return True
   if all_reduce_alg:
     flags.DEFINE_string(
         name="all_reduce_alg", short_name="ara", default=None,
......
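Per the validator above, --fp16_implementation=graph_rewrite requires --dtype=fp16, and --loss_scale=dynamic is accepted only together with the graph rewrite. A plausible invocation of the estimator ResNet with the new option (the paths below are placeholders, not taken from this commit):

python imagenet_main.py --dtype=fp16 --fp16_implementation=graph_rewrite \
    --loss_scale=dynamic --data_dir=/path/to/imagenet --model_dir=/tmp/resnet_model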