Commit c9c05e9b authored by Vinh Nguyen

fix transformer amp

parent c186c85a
@@ -72,7 +72,8 @@ def define_transformer_flags():
       loss_scale=True,
       all_reduce_alg=True,
       enable_xla=True,
-      force_v2_in_keras_compile=True
+      force_v2_in_keras_compile=True,
+      fp16_implementation=True
   )

   # Additional performance flags
@@ -85,9 +86,6 @@ def define_transformer_flags():
       'convolutions and batch normalizations, and this flag allows to '
       'disable it.'
   )
-  flags.DEFINE_boolean(
-      name='automatic_mixed_precision', default=False,
-      help='Enable automatic mixed precision training via a graph rewrite.')

   flags_core.define_benchmark()
   flags_core.define_device(tpu=True)
...
@@ -330,6 +330,30 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
                                    bleu_min=28,
                                    bleu_max=29)

+  def benchmark_8_gpu_fp16_amp(self):
+    """Benchmark 8 gpu with dynamic batch and fp16 with automatic mixed precision.
+
+    Should converge to 28.4 BLEU (uncased). This has not been verified yet.
+    """
+    self._setup()
+    FLAGS.num_gpus = 8
+    FLAGS.dtype = 'fp16'
+    FLAGS.fp16_implementation = 'graph_rewrite'
+    FLAGS.data_dir = self.train_data_dir
+    FLAGS.vocab_file = self.vocab_file
+    # Sets values directly to avoid validation check.
+    FLAGS['bleu_source'].value = self.bleu_source
+    FLAGS['bleu_ref'].value = self.bleu_ref
+    FLAGS.param_set = 'big'
+    FLAGS.batch_size = 3072*8
+    FLAGS.train_steps = 20000 * 12
+    FLAGS.steps_between_evals = 20000
+    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_amp')
+    self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+                                   log_steps=FLAGS.log_steps,
+                                   bleu_min=28,
+                                   bleu_max=29)
+
   def benchmark_8_gpu_static_batch_fp16(self):
     """Benchmark 8 gpu with static batch and fp16.
...
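A note on the FLAGS['bleu_source'].value assignments in the new benchmark above: setting a flag through attribute access (FLAGS.bleu_source = ...) re-runs any validators registered for it, while writing to the Flag object's .value bypasses them, which is what the "avoid validation check" comment relies on. Below is a minimal sketch of that absl.flags behavior, using an illustrative flag and validator rather than the validators defined in this repository:

from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('bleu_source', None, 'Path to source sentences.')
flags.register_validator(
    'bleu_source',
    lambda path: path is None or path.endswith('.en'),
    message='--bleu_source must point to a .en file')

FLAGS(['demo'])  # Parse with no command-line overrides.

# Direct assignment to the Flag object's value skips the validator above,
# mirroring the benchmark's "Sets values directly to avoid validation check".
FLAGS['bleu_source'].value = '/tmp/newstest2014.src'
print(FLAGS.bleu_source)

# By contrast, attribute assignment would trigger the validator and raise
# flags.IllegalFlagValueError for this path:
# FLAGS.bleu_source = '/tmp/newstest2014.src'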
@@ -376,11 +376,12 @@ class TransformerTask(object):
       opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
           opt, loss_scale=flags_core.get_loss_scale(self.flags_obj,
                                                     default_for_fp16="dynamic"))
-    if self.flags_obj.automatic_mixed_precision:
-      if params["dtype"] == tf.float16:
-        raise RuntimeError("Automatic mixed precision should not be called in conjunction with "
-                           "other types of mixed precision training. Set --dtype=fp32 instead.")
-      opt = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(opt)
+    if self.flags_obj.fp16_implementation == "graph_rewrite":
+      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
+      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32',
+      # which ensures tf.keras.mixed_precision and
+      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double up.
+      opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
     return opt
...
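The rewritten branch above hands the optimizer to tf.train.experimental.enable_mixed_precision_graph_rewrite and then relies on Keras compile/fit, so the loss scaling the rewrite adds is applied automatically. A hedged usage sketch with a toy model, assuming a TensorFlow release where this (since-deprecated) API is still available:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
opt = tf.keras.optimizers.SGD(0.01)
# Grappler rewrites eligible ops to float16 and adds dynamic loss scaling;
# the Python-level model and its variables stay float32.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)

model.compile(optimizer=opt, loss='mse')
model.fit(np.ones((32, 4), np.float32), np.ones((32, 1), np.float32),
          batch_size=8, epochs=1, verbose=0)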
@@ -190,9 +190,9 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
       return loss_scale > 0

   if fp16_implementation:
-    # Currently, this flag is only defined for the estimator resnet model.
+    # Currently, this flag is only defined for the estimator resnet and transformer models.
     flags.DEFINE_enum(
-        name="fp16_implementation", default="graph_rewrite",
+        name="fp16_implementation", default="casting",
         enum_values=("casting", "graph_rewrite"),
         help=help_wrap(
             "When --dtype=fp16, how fp16 should be implemented. This has no "
...
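For reference, the casting/graph_rewrite split that this default change exposes can be exercised end to end with a small standalone script. The flag definition below mirrors the one in the hunk, while the optimizer and model names are placeholders, not code from this repository:

from absl import app, flags
import tensorflow as tf

flags.DEFINE_enum(
    name='fp16_implementation', default='casting',
    enum_values=('casting', 'graph_rewrite'),
    help='When --dtype=fp16, how fp16 should be implemented.')
FLAGS = flags.FLAGS


def build_optimizer():
  opt = tf.keras.optimizers.Adam(1e-3)
  if FLAGS.fp16_implementation == 'graph_rewrite':
    # Graph-rewrite AMP: Grappler inserts the casts and loss scaling.
    opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
  else:
    # 'casting': the model casts tensors itself, so wrap with a loss-scale
    # optimizer as the Keras path in the earlier hunk does.
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        opt, loss_scale='dynamic')
  return opt


def main(_):
  print(type(build_optimizer()).__name__)


if __name__ == '__main__':
  app.run(main)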