Commit b578aee9 authored by Reed, committed by Toby Boyd

Fix Transformer Perfzero issue with fp16 (#7074)

parent adc8f11b
@@ -120,6 +120,15 @@ class TransformerTask(object):
     params["repeat_dataset"] = None
     params["dtype"] = flags_core.get_tf_dtype(flags_obj)
+    if params["dtype"] == tf.float16:
+      # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
+      # like this. What if multiple instances of TransformerTask are created?
+      # We should have a better way in the tf.keras.mixed_precision API of doing
+      # this.
+      policy = tf.keras.mixed_precision.experimental.Policy(
+          'infer_float32_vars')
+      tf.keras.mixed_precision.experimental.set_policy(policy)

   def train(self):
     """Trains the model."""
     params, flags_obj, is_train = self.params, self.flags_obj, True
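
The hunk above moves mixed precision setup into the task itself: when the resolved dtype is float16, the constructor installs a global Keras policy so variables stay in float32 while layers compute in float16. A minimal standalone sketch of the same pattern follows (same experimental API as in the diff; the helper name is illustrative, not part of this commit):

import tensorflow as tf

def _maybe_set_mixed_precision_policy(dtype):
  # Illustrative helper mirroring what the TransformerTask constructor now
  # does: 'infer_float32_vars' keeps layer variables in float32 while the
  # computation runs in the (float16) input dtype.
  if dtype == tf.float16:
    policy = tf.keras.mixed_precision.experimental.Policy(
        'infer_float32_vars')
    tf.keras.mixed_precision.experimental.set_policy(policy)
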
@@ -263,11 +272,6 @@ def _ensure_dir(log_dir):
 def main(_):
   flags_obj = flags.FLAGS
   with logger.benchmark_context(flags_obj):
-    if flags_core.get_tf_dtype(flags_obj) == 'float16':
-      policy = tf.keras.mixed_precision.experimental.Policy(
-          'infer_float32_vars')
-      tf.keras.mixed_precision.experimental.set_policy(policy)
     task = TransformerTask(flags_obj)
     if flags_obj.mode == "train":
       task.train()
@@ -92,9 +92,6 @@ class TransformerTaskTest(tf.test.TestCase):
     FLAGS.num_gpus = 2
     FLAGS.param_set = "base"
     FLAGS.dtype = "fp16"
-    policy = tf.keras.mixed_precision.experimental.Policy(
-        'infer_float32_vars')
-    tf.keras.mixed_precision.experimental.set_policy(policy)
     t = tm.TransformerTask(FLAGS)
     t.train()
@@ -131,9 +128,6 @@ class TransformerTaskTest(tf.test.TestCase):
   def test_predict_fp16(self):
     self._prepare_files_and_flags("--dtype=fp16")
-    policy = tf.keras.mixed_precision.experimental.Policy(
-        'infer_float32_vars')
-    tf.keras.mixed_precision.experimental.set_policy(policy)
     t = tm.TransformerTask(FLAGS)
     t.predict()
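
With the policy now installed by the constructor, main() and the tests only need to request fp16 through the dtype flag. A minimal usage sketch, using FLAGS and tm as in the test diff above:

FLAGS.dtype = "fp16"               # resolved to tf.float16 by flags_core.get_tf_dtype
task = tm.TransformerTask(FLAGS)   # constructor installs the global mixed precision policy
task.train()                       # runs under that policy
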