Unverified Commit 58a3de6c authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Add FP16 end-to-end tests (#7122)

parent 44ff121d
...@@ -284,14 +284,66 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -284,14 +284,66 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
FLAGS.max_length = 64 FLAGS.max_length = 64
FLAGS.train_steps = 100000 FLAGS.train_steps = 400000
FLAGS.steps_between_evals = 5000 FLAGS.steps_between_evals = 20000
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch') FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size, self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps, log_steps=FLAGS.log_steps,
bleu_min=28, bleu_min=28,
bleu_max=29) bleu_max=29)
def benchmark_8_gpu_static_batch_fp16(self):
"""Benchmark 8 gpu with static batch and fp16.
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.static_batch = True
FLAGS.max_length = 64
FLAGS.train_steps = 400000
FLAGS.steps_between_evals = 20000
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch_fp16')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=28,
bleu_max=29)
def benchmark_xla_8_gpu_static_batch_fp16(self):
"""Benchmark 8 gpu with static batch, XLA, and FP16.
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.enable_xla = True
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.static_batch = True
FLAGS.max_length = 64
FLAGS.train_steps = 400000
FLAGS.steps_between_evals = 20000
FLAGS.model_dir = self._get_model_dir(
'benchmark_xla_8_gpu_static_batch_fp16')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=28,
bleu_max=29)
class TransformerKerasBenchmark(TransformerBenchmark): class TransformerKerasBenchmark(TransformerBenchmark):
"""Benchmarks for Transformer (Base and Big) using Keras.""" """Benchmarks for Transformer (Base and Big) using Keras."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment