Unverified Commit b7e97bec authored by Toby Boyd, committed by GitHub

Add FP16 to transformer with benchmark tests. (#6756)

* Add FP16 and benchmarks.

* add missing run and report.

* Add loss_scale as option not included with dtype.

* move loss_scale validation under dtype conditional.

* add loss_scale to flags tested.
parent c0b31c51
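For context before the diff: the core change is that the loss_scale flag becomes an explicit opt-in, defined independently of dtype, with its validator registered only when the flag exists. A minimal, hypothetical sketch of that gating pattern (simplified from the define_performance changes shown below; names and defaults here are illustrative, not the exact upstream code):

from absl import flags

def define_performance_flags(dtype=False, loss_scale=False):
  """Hypothetical, simplified version of flags_core.define_performance."""
  if dtype:
    flags.DEFINE_enum(
        name="dtype", short_name="dt", default="fp32",
        enum_values=["fp16", "fp32"],
        help="The TensorFlow datatype used for calculations.")
  if loss_scale:  # Now opt-in, no longer implicitly tied to dtype.
    flags.DEFINE_string(
        name="loss_scale", short_name="ls", default=None,
        help="Loss scale for fp16 training, e.g. 128 or 'dynamic'.")

    @flags.validator(flag_name="loss_scale",
                     message="loss_scale should be a positive int/float.")
    def _check_loss_scale(loss_scale):  # pylint: disable=unused-variable
      if loss_scale is None:
        return True  # The None case is resolved later by get_loss_scale().
      if loss_scale == "dynamic":
        return True  # Allowed when dynamic loss scaling is enabled.
      try:
        return float(loss_scale) > 0
      except ValueError:
        return False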
......@@ -722,7 +722,8 @@ def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False,
datasets_num_private_threads=True,
datasets_num_parallel_batches=True,
dynamic_loss_scale=dynamic_loss_scale,
fp16_implementation=fp16_implementation)
fp16_implementation=fp16_implementation,
loss_scale=True)
flags_core.define_image()
flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core)
......
......@@ -139,35 +139,39 @@ class TransformerBaseEstimatorAccuracy(EstimatorBenchmark):
super(TransformerBaseEstimatorAccuracy, self).__init__(
output_dir=output_dir, flag_methods=flag_methods)
def benchmark_graph_1_gpu(self):
"""Benchmark graph mode 1 gpu.
def benchmark_graph_2_gpu(self):
"""Benchmark graph mode 2 gpus.
The paper uses 8 GPUs and a much larger effective batch size; this will
not converge to the 27.3 BLEU (uncased) SOTA.
"""
self._setup()
FLAGS.num_gpus = 1
FLAGS.num_gpus = 2
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 4096
FLAGS.batch_size = 4096 * 2
FLAGS.train_steps = 100000
FLAGS.steps_between_evals = 5000
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
FLAGS.hooks = ['ExamplesPerSecondHook']
self._run_and_report_benchmark()
# These bleu scores are based on test runs at this limited number of
# steps and batch size, after verifying SOTA at 8xV100s.
self._run_and_report_benchmark(bleu_min=25.3, bleu_max=26)
def benchmark_graph_2_gpu(self):
"""Benchmark graph mode 2 gpus.
def benchmark_graph_fp16_2_gpu(self):
"""Benchmark 2 gpu with fp16 mixed-precision.
The paper uses 8 GPUs and a much larger effective batch size; this will
not converge to the 27.3 BLEU (uncased) SOTA.
The paper uses 8 GPUs and a much larger effective batch size;
this is unlikely to hit the target bleu score regardless of the
number of steps.
"""
self._setup()
FLAGS.num_gpus = 2
FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
......@@ -177,14 +181,17 @@ class TransformerBaseEstimatorAccuracy(EstimatorBenchmark):
FLAGS.batch_size = 4096 * 2
FLAGS.train_steps = 100000
FLAGS.steps_between_evals = 5000
FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_2_gpu')
FLAGS.hooks = ['ExamplesPerSecondHook']
self._run_and_report_benchmark()
# These bleu scores are based on test runs at this limited number of
# steps and batch size, after verifying SOTA at 8xV100s.
self._run_and_report_benchmark(bleu_min=25.3, bleu_max=26)
def benchmark_graph_8_gpu(self):
"""Benchmark graph mode 8 gpus.
SOTA is 27.3 BLEU (uncased).
Best so far is 27.2 with batch size 4048 * 8 at 75,000 steps.
Another test with batch size 2024 * 8 peaked at 26.66 at 100,000 steps.
"""
self._setup()
FLAGS.num_gpus = 8
......@@ -194,21 +201,48 @@ class TransformerBaseEstimatorAccuracy(EstimatorBenchmark):
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 2048 * 8
FLAGS.batch_size = 3072 * 8
FLAGS.train_steps = 100000
FLAGS.steps_between_evals = 5000
FLAGS.model_dir = self._get_model_dir('benchmark_graph_8_gpu')
FLAGS.hooks = ['ExamplesPerSecondHook']
self._run_and_report_benchmark()
def _run_and_report_benchmark(self):
def benchmark_graph_fp16_8_gpu(self):
"""benchmark 8 gpus with fp16 mixed precision.
SOTA is 27.3 BLEU (uncased).
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 3072 * 8
FLAGS.train_steps = 100000
FLAGS.steps_between_evals = 5000
FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_8_gpu')
FLAGS.hooks = ['ExamplesPerSecondHook']
self._run_and_report_benchmark()
def _run_and_report_benchmark(self, bleu_min=27.3, bleu_max=28):
"""Run benchmark and report results.
Args:
bleu_min: minimum expected uncased bleu. Default is SOTA.
bleu_max: maximum expected uncased bleu. Default is a high number.
"""
start_time_sec = time.time()
stats = transformer_main.run_transformer(flags.FLAGS)
wall_time_sec = time.time() - start_time_sec
self._report_benchmark(stats,
wall_time_sec,
bleu_min=27.2,
bleu_max=28)
bleu_min=bleu_min,
bleu_max=bleu_max)
class TransformerBaseEstimatorBenchmark(EstimatorBenchmark):
......@@ -227,34 +261,70 @@ class TransformerBaseEstimatorBenchmark(EstimatorBenchmark):
"""Benchmark graph 1 gpu."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = 2048
FLAGS.batch_size = 4096
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
self._run_and_report_benchmark()
def benchmark_graph_fp16_1_gpu(self):
"""Benchmark graph fp16 1 gpu."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.dtype = 'fp16'
FLAGS.batch_size = 4096
FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_1_gpu')
self._run_and_report_benchmark()
def benchmark_graph_2_gpu(self):
"""Benchmark graph 2 gpus."""
self._setup()
FLAGS.num_gpus = 2
FLAGS.batch_size = 2048 * 2
FLAGS.batch_size = 4096 * 2
FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
self._run_and_report_benchmark()
def benchmark_graph_fp16_2_gpu(self):
"""Benchmark graph fp16 2 gpus."""
self._setup()
FLAGS.num_gpus = 2
FLAGS.dtype = 'fp16'
FLAGS.batch_size = 4096 * 2
FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_2_gpu')
self._run_and_report_benchmark()
def benchmark_graph_4_gpu(self):
"""Benchmark graph 4 gpus."""
self._setup()
FLAGS.num_gpus = 4
FLAGS.batch_size = 2048 * 4
FLAGS.batch_size = 4096 * 4
FLAGS.model_dir = self._get_model_dir('benchmark_graph_4_gpu')
self._run_and_report_benchmark()
def benchmark_graph_fp16_4_gpu(self):
"""Benchmark 4 graph fp16 gpus."""
self._setup()
FLAGS.num_gpus = 4
FLAGS.dtype = 'fp16'
FLAGS.batch_size = 4096 * 4
FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_4_gpu')
self._run_and_report_benchmark()
def benchmark_graph_8_gpu(self):
"""Benchmark graph 8 gpus."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.batch_size = 2048 * 8
FLAGS.batch_size = 4096 * 8
FLAGS.model_dir = self._get_model_dir('benchmark_graph_8_gpu')
self._run_and_report_benchmark()
def benchmark_graph_fp16_8_gpu(self):
"""Benchmark graph fp16 8 gpus."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.batch_size = 4096 * 8
FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_8_gpu')
self._run_and_report_benchmark()
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = transformer_main.run_transformer(flags.FLAGS)
......
......@@ -182,6 +182,11 @@ def get_train_op_and_metrics(loss, params):
if params["use_tpu"] and params["tpu"] != tpu_util.LOCAL:
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
# Uses automatic mixed precision FP16 training if on GPU.
if params["dtype"] == "fp16":
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer)
# Calculate and apply gradients using LazyAdamOptimizer.
global_step = tf.train.get_global_step()
tvars = tf.trainable_variables()
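The graph-rewrite call added above wraps an existing optimizer so TensorFlow automatically casts eligible ops to fp16 and applies loss scaling. A hedged, self-contained sketch of the typical TF 1.x usage (assumes TF >= 1.14, where enable_mixed_precision_graph_rewrite accepts an optional loss_scale argument; the toy model here is illustrative only):

import tensorflow as tf  # TF 1.x API; the rewrite needs TF >= 1.14.

x = tf.random.normal([32, 64])
w = tf.get_variable("w", shape=[64, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
# Wraps the optimizer with loss scaling and registers a graph pass that
# rewrites eligible fp32 ops to run in fp16 on supported GPUs.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    optimizer, loss_scale="dynamic")
train_op = optimizer.minimize(loss)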
......@@ -232,8 +237,8 @@ def evaluate_and_log_bleu(estimator, bleu_source, bleu_ref, vocab_file):
uncased_score, cased_score = translate_and_compute_bleu(
estimator, subtokenizer, bleu_source, bleu_ref)
tf.logging.info("Bleu score (uncased): %d", uncased_score)
tf.logging.info("Bleu score (cased): %d", cased_score)
tf.logging.info("Bleu score (uncased): %f", uncased_score)
tf.logging.info("Bleu score (cased): %f", cased_score)
return uncased_score, cased_score
......@@ -392,7 +397,7 @@ def define_transformer_flags():
intra_op=False,
synthetic_data=True,
max_train_steps=False,
dtype=False,
dtype=True,
all_reduce_alg=True
)
flags_core.define_benchmark()
......
......@@ -34,7 +34,7 @@ DTYPE_MAP = {
def get_tf_dtype(flags_obj):
if getattr(flags_obj, 'fp16_implementation', None) == 'graph_rewrite':
if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
# If the graph_rewrite is used, we build the graph with fp32, and let the
# graph rewrite change ops to fp16.
return tf.float32
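For reference, a hedged reconstruction of the helper this hunk modifies; the DTYPE_MAP contents are an assumption based on the fp16/fp32 flag values used elsewhere in this diff:

import tensorflow as tf

# Assumed DTYPE_MAP contents, inferred from the fp16/fp32 flag values.
DTYPE_MAP = {
    "fp16": tf.float16,
    "fp32": tf.float32,
}

def get_tf_dtype(flags_obj):
  if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
    # With the graph rewrite, the graph is built in fp32 and a rewrite
    # pass later changes eligible ops to fp16.
    return tf.float32
  return DTYPE_MAP[flags_obj.dtype]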
......@@ -55,7 +55,8 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
tf_gpu_thread_mode=False,
datasets_num_private_threads=False,
datasets_num_parallel_batches=False,
dynamic_loss_scale=False, fp16_implementation=False):
dynamic_loss_scale=False, fp16_implementation=False,
loss_scale=False):
"""Register flags for specifying performance tuning arguments.
Args:
......@@ -76,6 +77,8 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
dynamic_loss_scale: Allow the "loss_scale" flag to take on the value
"dynamic". Only valid if `dtype` is True.
fp16_implementation: Create fp16_implementation flag.
loss_scale: Create loss_scale flag, used for loss scaling in
mixed-precision training. Can only be turned on if `dtype` is also True.
Returns:
A list of flags for core.py to mark as key flags.
......@@ -152,13 +155,15 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
loss_scale_help_text = loss_scale_help_text.format(
"This must be an int/float", "")
loss_scale_validation_msg = "loss_scale should be a positive int/float."
if loss_scale:
flags.DEFINE_string(
name="loss_scale", short_name="ls", default=None,
help=help_wrap(loss_scale_help_text))
@flags.validator(flag_name="loss_scale", message=loss_scale_validation_msg)
@flags.validator(flag_name="loss_scale",
message=loss_scale_validation_msg)
def _check_loss_scale(loss_scale): # pylint: disable=unused-variable
"""Validator to check the loss scale flag is valid"""
"""Validator to check the loss scale flag is valid."""
if loss_scale is None:
return True # null case is handled in get_loss_scale()
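The None case deferred to get_loss_scale() above resolves to a dtype-based default. A hypothetical sketch of that resolution (the specific default values are assumptions, not confirmed by this diff):

def get_loss_scale(flags_obj, default_for_fp16=128):
  """Hypothetical sketch; the dtype-based defaults are assumptions."""
  if flags_obj.loss_scale == "dynamic":
    return "dynamic"
  if flags_obj.loss_scale is not None:
    return float(flags_obj.loss_scale)
  # None case: fp16 needs a scale > 1 so small gradients do not flush
  # to zero in half precision; fp32 needs no scaling.
  return default_for_fp16 if flags_obj.dtype == "fp16" else 1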
......@@ -175,8 +180,8 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
if fp16_implementation:
# Currently, this flag is only defined for the estimator resnet model.
flags.DEFINE_enum(
name="fp16_implementation", default='casting',
enum_values=('casting', 'graph_rewrite'),
name="fp16_implementation", default="casting",
enum_values=("casting', 'graph_rewrite"),
help=help_wrap(
"When --dtype=fp16, how fp16 should be implemented. This has no "
"impact on correctness. 'casting' will cause manual tf.casts to "
......@@ -184,19 +189,19 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
"tf.train.experimental.enable_mixed_precision_graph_rewrite will "
"be used to automatically use fp16 without any manual casts."))
@flags.multi_flags_validator(['fp16_implementation', 'dtype',
'loss_scale'])
@flags.multi_flags_validator(["fp16_implementation", "dtype",
"loss_scale"])
def _check_fp16_implementation(flags_dict):
"""Validator to check fp16_implementation flag is valid."""
if (flags_dict['fp16_implementation'] == 'graph_rewrite' and
flags_dict['dtype'] != 'fp16'):
raise flags.ValidationError('--fp16_implementation should not be '
'specified unless --dtype=fp16')
if (flags_dict['fp16_implementation'] != 'graph_rewrite' and
flags_dict['loss_scale'] == 'dynamic'):
raise flags.ValidationError('--loss_scale=dynamic is only supported '
'when '
'--fp16_implementation=graph_rewrite')
if (flags_dict["fp16_implementation"] == "graph_rewrite" and
flags_dict["dtype"] != "fp16"):
raise flags.ValidationError("--fp16_implementation should not be "
"specified unless --dtype=fp16")
if (flags_dict["fp16_implementation"] != "graph_rewrite" and
flags_dict["loss_scale"] == "dynamic"):
raise flags.ValidationError("--loss_scale=dynamic is only supported "
"when "
"--fp16_implementation=graph_rewrite")
return True
if all_reduce_alg:
......
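To illustrate the combinations _check_fp16_implementation accepts and rejects, a standalone mirror of its logic with example inputs (illustrative only, not part of the diff):

def valid_combo(fp16_implementation, dtype, loss_scale):
  """Standalone mirror of _check_fp16_implementation for illustration."""
  if fp16_implementation == "graph_rewrite" and dtype != "fp16":
    return False  # The graph rewrite only applies when --dtype=fp16.
  if fp16_implementation != "graph_rewrite" and loss_scale == "dynamic":
    return False  # Dynamic loss scaling requires the graph rewrite.
  return True

assert valid_combo("graph_rewrite", "fp16", "dynamic")
assert valid_combo("casting", "fp16", "128")
assert not valid_combo("graph_rewrite", "fp32", None)
assert not valid_combo("casting", "fp16", "dynamic")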
......@@ -23,7 +23,7 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp
def define_flags():
flags_core.define_base(num_gpu=False)
flags_core.define_performance(dynamic_loss_scale=True)
flags_core.define_performance(dynamic_loss_scale=True, loss_scale=True)
flags_core.define_image()
flags_core.define_benchmark()
......