Commit c9c05e9b authored by Vinh Nguyen's avatar Vinh Nguyen
Browse files

fix transformer amp

parent c186c85a
......@@ -72,7 +72,8 @@ def define_transformer_flags():
loss_scale=True,
all_reduce_alg=True,
enable_xla=True,
force_v2_in_keras_compile=True
force_v2_in_keras_compile=True,
fp16_implementation=True
)
# Additional performance flags
......@@ -85,9 +86,6 @@ def define_transformer_flags():
'convolutions and batch normalizations, and this flag allows to '
'disable it.'
)
flags.DEFINE_boolean(
name='automatic_mixed_precision', default=False,
help='Enable automatic mixed precision training via a graph rewrite.')
flags_core.define_benchmark()
flags_core.define_device(tpu=True)
......
......@@ -330,6 +330,30 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
bleu_min=28,
bleu_max=29)
def benchmark_8_gpu_fp16_amp(self):
"""Benchmark 8 gpu with dynamic batch and fp16 with automatic mixed precision.
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12
FLAGS.steps_between_evals = 20000
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_amp')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=28,
bleu_max=29)
def benchmark_8_gpu_static_batch_fp16(self):
"""Benchmark 8 gpu with static batch and fp16.
......
......@@ -376,11 +376,12 @@ class TransformerTask(object):
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
opt, loss_scale=flags_core.get_loss_scale(self.flags_obj,
default_for_fp16="dynamic"))
if self.flags_obj.automatic_mixed_precision:
if params["dtype"] == tf.float16:
raise RuntimeError("Automatic mixed precision should not be called in conjunction with "
"other types of mixed precision training. Set --dtype=fp32 instead.")
opt = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(opt)
if self.flags_obj.fp16_implementation == "graph_rewrite":
# Note: when flags_obj.fp16_implementation == "graph_rewrite",
# dtype as determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.keras.mixed_precision and tf.train.experimental.enable_mixed_precision_graph_rewrite
# does not double up.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
return opt
......
......@@ -190,9 +190,9 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
return loss_scale > 0
if fp16_implementation:
# Currently, this flag is only defined for the estimator resnet model.
# Currently, this flag is only defined for the estimator resnet and transformer models.
flags.DEFINE_enum(
name="fp16_implementation", default="graph_rewrite",
name="fp16_implementation", default="casting",
enum_values=("casting', 'graph_rewrite"),
help=help_wrap(
"When --dtype=fp16, how fp16 should be implemented. This has no "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment