Commit 6c965160 authored by Vinh Nguyen

Merge remote-tracking branch 'origin/amp_transformer' into amp_resnet50

parents 901c4cc4 b8f58ec8
@@ -85,7 +85,10 @@ def define_transformer_flags():
       'convolutions and batch normalizations, and this flag allows to '
       'disable it.'
   )
+  flags.DEFINE_boolean(
+      name='automatic_mixed_precision', default=False,
+      help='Enable automatic mixed precision training via a graph rewrite.')
   flags_core.define_benchmark()
   flags_core.define_device(tpu=True)
...
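For context, a minimal, self-contained sketch (not part of this commit) of how an absl boolean flag such as --automatic_mixed_precision is defined and read from a training script; the main() body below is illustrative only, while the real model reads the flag through self.flags_obj as shown in the next hunk:

from absl import app, flags

flags.DEFINE_boolean(
    name='automatic_mixed_precision', default=False,
    help='Enable automatic mixed precision training via a graph rewrite.')

FLAGS = flags.FLAGS

def main(_):
  # Hypothetical consumer of the flag, for illustration only.
  if FLAGS.automatic_mixed_precision:
    print('AMP graph rewrite will be enabled when the optimizer is created.')

if __name__ == '__main__':
  app.run(main)  # e.g. python transformer_main.py --automatic_mixed_precision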
@@ -376,6 +376,12 @@ class TransformerTask(object):
       opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
           opt, loss_scale=flags_core.get_loss_scale(self.flags_obj,
                                                     default_for_fp16="dynamic"))
+    if self.flags_obj.automatic_mixed_precision:
+      if params["dtype"] == tf.float16:
+        raise RuntimeError("Automatic mixed precision should not be called in conjunction with "
+                           "other types of mixed precision training. Set --dtype=fp32 instead.")
+      opt = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(opt)
     return opt
...
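The graph-rewrite call in the new branch is the standard TensorFlow API; a minimal sketch of what it does, assuming TensorFlow 1.14+ (tf.compat.v1 namespace) and a plain float32 Keras optimizer like the one built in this method:

import tensorflow as tf

# Illustrative only: create a float32 optimizer, then let the AMP graph
# rewrite insert float16 casts and automatic loss scaling around it.
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
opt = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(opt)

Because the rewrite already performs its own loss scaling, it must not be combined with --dtype=fp16, which wraps the optimizer in a LossScaleOptimizer above; that conflict is why the new branch raises a RuntimeError.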