Commit b8f58ec8 authored by Vinh Nguyen

Add automatic mixed precision training to Transformer v2

parent 58340818
@@ -85,7 +85,10 @@ def define_transformer_flags():
       'convolutions and batch normalizations, and this flag allows to '
       'disable it.'
   )
+  flags.DEFINE_boolean(
+      name='automatic_mixed_precision', default=False,
+      help='Enable automatic mixed precision training via a graph rewrite.')
   flags_core.define_benchmark()
   flags_core.define_device(tpu=True)
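For context, a minimal standalone sketch of how a boolean flag like the one added above is defined and read with absl; the wrapper script below is illustrative and not part of this commit:

```python
# Minimal sketch: defining and consuming the new boolean flag with absl.
from absl import app
from absl import flags

flags.DEFINE_boolean(
    name='automatic_mixed_precision', default=False,
    help='Enable automatic mixed precision training via a graph rewrite.')

FLAGS = flags.FLAGS


def main(_):
  # absl accepts --automatic_mixed_precision / --noautomatic_mixed_precision.
  if FLAGS.automatic_mixed_precision:
    print('AMP graph rewrite requested.')


if __name__ == '__main__':
  app.run(main)
```

Booleans defined this way are toggled on the command line as `--automatic_mixed_precision` or `--noautomatic_mixed_precision`.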
@@ -259,6 +259,12 @@ class TransformerTask(object):
       opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
           opt, loss_scale=flags_core.get_loss_scale(self.flags_obj,
                                                     default_for_fp16="dynamic"))
+    if self.flags_obj.automatic_mixed_precision:
+      if params["dtype"] == tf.float16:
+        raise RuntimeError(
+            "Automatic mixed precision should not be combined with other "
+            "types of mixed precision training. Set --dtype=fp32 instead.")
+      opt = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(opt)
     return opt
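For reference, a self-contained sketch of the graph-rewrite API the new branch calls, assuming a GPU build of TF 1.15 / early TF 2.x where `enable_mixed_precision_graph_rewrite` is still available; the toy model and training step are illustrative only:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Toy regression graph; stands in for the Transformer model.
x = tf.random.normal([32, 1024])
w = tf.Variable(tf.random.normal([1024, 1024]))
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

opt = tf.train.AdamOptimizer(learning_rate=1e-3)
# The rewrite wraps `opt` in a loss-scaling optimizer and marks the graph
# so Grappler casts eligible ops (e.g. matmuls) to float16 on GPU.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
train_op = opt.minimize(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)
```

This is also why the guard above rejects `--dtype=fp16`: that path already wraps the optimizer in a Keras `LossScaleOptimizer`, and stacking the rewrite's own loss scaling on top of it would apply mixed precision twice.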