Commit 383c6e30 authored by guptapriya, committed by Toby Boyd

Add static batch benchmarks to estimator (#6886)

* Add static batch benchmarks to estimator

So we can distinguish how much static vs. dynamic batching matters.

* Change max_length for static_batch tests

* Add flag for max_length
parent 3928d481
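For context on what the new benchmarks measure: with static batching every example is padded to a fixed max_length, so batch shapes never change; the existing dynamic path buckets examples by length, so shorter sentences carry less padding. Below is a minimal illustrative sketch of the two modes, assuming a tf.data pipeline with made-up names (this is not the repo's data_pipeline.py):

```python
# Illustrative sketch of static vs. dynamic batching (assumed names; this is
# not the Transformer model's data_pipeline.py).
import tensorflow as tf

MAX_LENGTH = 40          # mirrors FLAGS.max_length in the new benchmarks
TOKENS_PER_BATCH = 3072  # per-GPU token budget, assuming 3072 * 8 is split evenly


def toy_dataset():
  """Variable-length sequences of token ids, standing in for subtokenized sentences."""
  lengths = [5, 12, 33, 40, 7, 21]
  return tf.data.Dataset.from_generator(
      lambda: ([1] * n for n in lengths),
      output_types=tf.int64,
      output_shapes=tf.TensorShape([None]))


def static_batch(dataset):
  # Static batch: pad every example to MAX_LENGTH, so each batch has the
  # fixed shape [TOKENS_PER_BATCH // MAX_LENGTH, MAX_LENGTH].
  return dataset.padded_batch(
      TOKENS_PER_BATCH // MAX_LENGTH,
      padded_shapes=[MAX_LENGTH],
      drop_remainder=True)


def dynamic_batch(dataset):
  # Dynamic batch: bucket by sequence length so short sentences share larger
  # batches with less padding; batch shapes vary from step to step.
  boundaries = [10, 20, 30, MAX_LENGTH]
  batch_sizes = [TOKENS_PER_BATCH // b for b in boundaries + [MAX_LENGTH]]
  return dataset.apply(
      tf.data.experimental.bucket_by_sequence_length(
          element_length_func=lambda x: tf.size(x),
          bucket_boundaries=boundaries,
          bucket_batch_sizes=batch_sizes))
```

Fixed shapes waste compute on padding but keep step time uniform (and are required on TPU, as the `static_batch or use_tpu` line further down reflects); that trade-off is what these benchmarks quantify.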
@@ -159,6 +159,28 @@ class TransformerBigEstimatorAccuracy(EstimatorBenchmark):
    FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def benchmark_graph_8_gpu_static_batch(self):
    """Benchmark graph mode 8 gpus with static batch.

    SOTA is 28.4 BLEU (uncased).
    """
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.data_dir = self.train_data_dir
    FLAGS.vocab_file = self.vocab_file
    # Sets values directly to avoid validation check.
    FLAGS['bleu_source'].value = self.bleu_source
    FLAGS['bleu_ref'].value = self.bleu_ref
    FLAGS.param_set = 'big'
    FLAGS.batch_size = 3072 * 8
    FLAGS.static_batch = True
    FLAGS.max_length = 40
    FLAGS.train_steps = 100000
    FLAGS.steps_between_evals = 5000
    FLAGS.model_dir = self._get_model_dir('benchmark_graph_8_gpu_static_batch')
    FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def _run_and_report_benchmark(self, bleu_min=28.3, bleu_max=29):
    """Run benchmark and report results.
@@ -254,6 +276,31 @@ class TransformerBaseEstimatorAccuracy(EstimatorBenchmark):
    FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def benchmark_graph_8_gpu_static_batch(self):
    """Benchmark graph mode 8 gpus with static batch.

    SOTA is 27.3 BLEU (uncased).
    Best so far is 27.2 with 4048*8 at 75,000 steps.
    27.009 with 4096*8 at 100,000 steps and earlier.
    Other test: 2024*8 peaked at 26.66 at 100,000 steps.
    """
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.data_dir = self.train_data_dir
    FLAGS.vocab_file = self.vocab_file
    # Sets values directly to avoid validation check.
    FLAGS['bleu_source'].value = self.bleu_source
    FLAGS['bleu_ref'].value = self.bleu_ref
    FLAGS.param_set = 'base'
    FLAGS.batch_size = 4096 * 8
    FLAGS.static_batch = True
    FLAGS.max_length = 40
    FLAGS.train_steps = 100000
    FLAGS.steps_between_evals = 5000
    FLAGS.model_dir = self._get_model_dir('benchmark_graph_8_gpu_static_batch')
    FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()
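A quick back-of-the-envelope on what the two static-batch configs above (big and base) translate to in sentences: batch_size here is a token budget, and this sketch assumes the global budget is split evenly across the 8 GPUs and that the static path packs batch_size // max_length padded sentences per batch.

```python
# Rough arithmetic for the static-batch configs, under the assumptions stated
# in the text above (even per-GPU split, batch_size // max_length sentences).
def sentences_per_gpu_batch(global_batch_tokens, num_gpus, max_length):
  per_gpu_tokens = global_batch_tokens // num_gpus
  return per_gpu_tokens // max_length

print(sentences_per_gpu_batch(3072 * 8, 8, 40))  # big:  76 padded sentences per GPU
print(sentences_per_gpu_batch(4096 * 8, 8, 40))  # base: 102 padded sentences per GPU
```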

  def benchmark_graph_fp16_8_gpu(self):
    """benchmark 8 gpus with fp16 mixed precision.
@@ -390,6 +390,10 @@ def run_loop(
def define_transformer_flags():
  """Add flags and flag validators for running transformer_main."""
  # Add common flags (data_dir, model_dir, train_epochs, etc.).
  flags.DEFINE_integer(
      name="max_length", short_name="ml", default=None,
      help=flags_core.help_wrap("Max length: maximum number of tokens per example."))
  flags_core.define_base()
  flags_core.define_performance(
      num_parallel_calls=True,
@@ -579,6 +583,8 @@ def run_transformer(flags_obj):
params["static_batch"] = flags_obj.static_batch or params["use_tpu"]
params["allow_ffn_pad"] = not params["use_tpu"]
params["max_length"] = flags_obj.max_length or params['max_length']
params["use_synthetic_data"] = flags_obj.use_synthetic_data
# Set batch size parameter, which depends on the availability of
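The `flags_obj.max_length or params["max_length"]` assignment means the new flag only takes effect when it is explicitly set; left at its None default, the param set's existing max_length is used. A small standalone sketch of that fallback (the params dict below stands in for the real param set, not the model_params module):

```python
# Standalone sketch of the flag-to-params fallback; the params dict is an
# assumed stand-in for the Transformer param-set defaults.
from absl import flags

flags.DEFINE_integer("max_length", None, "Max number of tokens per example.")
FLAGS = flags.FLAGS
FLAGS(["demo"])  # parse an empty command line so flag values are usable

params = {"max_length": 256}  # assumed param-set default

# Unset flag (None) falls back to the param-set value. Note the `or` idiom
# would also treat an explicit 0 as unset, which is harmless for a length.
params["max_length"] = FLAGS.max_length or params["max_length"]
print(params["max_length"])   # 256

# An explicitly set flag wins, as in the new static-batch benchmarks.
FLAGS.max_length = 40
params["max_length"] = FLAGS.max_length or params["max_length"]
print(params["max_length"])   # 40
```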