Unverified Commit b7e97bec authored by Toby Boyd, committed by GitHub

Add FP16 to transformer with benchmark tests. (#6756)

* Add FP16 and benchmarks.

* add missing run and report.

* Add loss_scale as option not included with dtype.

* move loss_scale validation under dtype conditional.

* add loss_scale to flags tested.
parent c0b31c51
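The heart of the change is letting TensorFlow's automatic mixed-precision graph rewrite wrap the training optimizer whenever --dtype=fp16 is set. A minimal sketch of that pattern, assuming the TF 1.14-era experimental API and a stand-in AdamOptimizer (the transformer itself uses contrib's LazyAdamOptimizer):

```python
import tensorflow as tf

# Stand-in optimizer for illustration only.
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

# The wrapper makes the grappler pass rewrite eligible ops to fp16 and
# applies loss scaling automatically (dynamic scaling by default).
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    optimizer)

# From here, minimize() builds a loss-scaled, mixed-precision train op.
```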
@@ -722,7 +722,8 @@ def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False,
       datasets_num_private_threads=True,
       datasets_num_parallel_batches=True,
       dynamic_loss_scale=dynamic_loss_scale,
-      fp16_implementation=fp16_implementation)
+      fp16_implementation=fp16_implementation,
+      loss_scale=True)
   flags_core.define_image()
   flags_core.define_benchmark()
   flags.adopt_module_key_flags(flags_core)
...
@@ -139,35 +139,39 @@ class TransformerBaseEstimatorAccuracy(EstimatorBenchmark):
     super(TransformerBaseEstimatorAccuracy, self).__init__(
        output_dir=output_dir, flag_methods=flag_methods)
 
-  def benchmark_graph_1_gpu(self):
-    """Benchmark graph mode 1 gpu.
+  def benchmark_graph_2_gpu(self):
+    """Benchmark graph mode 2 gpus.
 
     The paper uses 8 GPUs and a much larger effective batch size; this will
     not converge to the 27.3 BLEU (uncased) SOTA.
     """
     self._setup()
-    FLAGS.num_gpus = 1
+    FLAGS.num_gpus = 2
     FLAGS.data_dir = self.train_data_dir
     FLAGS.vocab_file = self.vocab_file
     # Sets values directly to avoid validation check.
     FLAGS['bleu_source'].value = self.bleu_source
     FLAGS['bleu_ref'].value = self.bleu_ref
     FLAGS.param_set = 'base'
-    FLAGS.batch_size = 4096
+    FLAGS.batch_size = 4096 * 2
     FLAGS.train_steps = 100000
     FLAGS.steps_between_evals = 5000
-    FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
+    FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
     FLAGS.hooks = ['ExamplesPerSecondHook']
-    self._run_and_report_benchmark()
+    # These bleu scores are based on test runs at this limited number of
+    # steps and batch size, after verifying SOTA at 8xV100s.
+    self._run_and_report_benchmark(bleu_min=25.3, bleu_max=26)
 
-  def benchmark_graph_2_gpu(self):
-    """Benchmark graph mode 2 gpus.
+  def benchmark_graph_fp16_2_gpu(self):
+    """Benchmark 2 gpus with fp16 mixed-precision.
 
-    The paper uses 8 GPUs and a much larger effective batch size; this will
-    not converge to the 27.3 BLEU (uncased) SOTA.
+    The paper uses 8 GPUs and a much larger effective batch size;
+    this is unlikely to hit the target bleu score regardless of
+    the number of steps.
     """
     self._setup()
     FLAGS.num_gpus = 2
+    FLAGS.dtype = 'fp16'
     FLAGS.data_dir = self.train_data_dir
     FLAGS.vocab_file = self.vocab_file
     # Sets values directly to avoid validation check.
@@ -177,14 +181,17 @@ class TransformerBaseEstimatorAccuracy(EstimatorBenchmark):
     FLAGS.batch_size = 4096 * 2
     FLAGS.train_steps = 100000
     FLAGS.steps_between_evals = 5000
-    FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
+    FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_2_gpu')
     FLAGS.hooks = ['ExamplesPerSecondHook']
-    self._run_and_report_benchmark()
+    # These bleu scores are based on test runs at this limited number of
+    # steps and batch size, after verifying SOTA at 8xV100s.
+    self._run_and_report_benchmark(bleu_min=25.3, bleu_max=26)
 
   def benchmark_graph_8_gpu(self):
     """Benchmark graph mode 8 gpus.
 
-    SOTA is 27.3 BLEU (uncased).
+    Best so far is 27.2 with 4048 * 8 at 75,000 steps.
+    Other test: 2024 * 8 peaked at 26.66 at 100,000 steps.
     """
     self._setup()
     FLAGS.num_gpus = 8
@@ -194,21 +201,48 @@ class TransformerBaseEstimatorAccuracy(EstimatorBenchmark):
     FLAGS['bleu_source'].value = self.bleu_source
     FLAGS['bleu_ref'].value = self.bleu_ref
     FLAGS.param_set = 'base'
-    FLAGS.batch_size = 2048 * 8
+    FLAGS.batch_size = 3072 * 8
     FLAGS.train_steps = 100000
     FLAGS.steps_between_evals = 5000
     FLAGS.model_dir = self._get_model_dir('benchmark_graph_8_gpu')
     FLAGS.hooks = ['ExamplesPerSecondHook']
     self._run_and_report_benchmark()
 
-  def _run_and_report_benchmark(self):
+  def benchmark_graph_fp16_8_gpu(self):
+    """Benchmark 8 gpus with fp16 mixed precision.
+
+    SOTA is 27.3 BLEU (uncased).
+    """
+    self._setup()
+    FLAGS.num_gpus = 8
+    FLAGS.dtype = 'fp16'
+    FLAGS.data_dir = self.train_data_dir
+    FLAGS.vocab_file = self.vocab_file
+    # Sets values directly to avoid validation check.
+    FLAGS['bleu_source'].value = self.bleu_source
+    FLAGS['bleu_ref'].value = self.bleu_ref
+    FLAGS.param_set = 'base'
+    FLAGS.batch_size = 3072 * 8
+    FLAGS.train_steps = 100000
+    FLAGS.steps_between_evals = 5000
+    FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_8_gpu')
+    FLAGS.hooks = ['ExamplesPerSecondHook']
+    self._run_and_report_benchmark()
+
+  def _run_and_report_benchmark(self, bleu_min=27.3, bleu_max=28):
+    """Run benchmark and report results.
+
+    Args:
+      bleu_min: Minimum expected uncased bleu. Default is SOTA.
+      bleu_max: Maximum expected uncased bleu. Default is a high number.
+    """
     start_time_sec = time.time()
     stats = transformer_main.run_transformer(flags.FLAGS)
     wall_time_sec = time.time() - start_time_sec
 
     self._report_benchmark(stats,
                            wall_time_sec,
-                           bleu_min=27.2,
-                           bleu_max=28)
+                           bleu_min=bleu_min,
+                           bleu_max=bleu_max)
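The new bleu_min/bleu_max arguments bound the score each configuration is allowed to report. A hypothetical sketch of the check _report_benchmark is expected to apply (the real helper lives in the shared EstimatorBenchmark base, and the stats key name here is assumed):

```python
def _check_bleu_bounds(stats, bleu_min, bleu_max):
  """Fails the run if uncased BLEU lands outside [bleu_min, bleu_max]."""
  bleu = stats["bleu_uncased"]  # assumed key; the real layout may differ
  if not bleu_min <= bleu <= bleu_max:
    raise AssertionError("BLEU %f outside expected range [%f, %f]"
                         % (bleu, bleu_min, bleu_max))
```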
class TransformerBaseEstimatorBenchmark(EstimatorBenchmark):
@@ -227,34 +261,70 @@ class TransformerBaseEstimatorBenchmark(EstimatorBenchmark):
     """Benchmark graph 1 gpu."""
     self._setup()
     FLAGS.num_gpus = 1
-    FLAGS.batch_size = 2048
+    FLAGS.batch_size = 4096
     FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
     self._run_and_report_benchmark()
 
+  def benchmark_graph_fp16_1_gpu(self):
+    """Benchmark graph fp16 1 gpu."""
+    self._setup()
+    FLAGS.num_gpus = 1
+    FLAGS.dtype = 'fp16'
+    FLAGS.batch_size = 4096
+    FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_1_gpu')
+    self._run_and_report_benchmark()
+
   def benchmark_graph_2_gpu(self):
     """Benchmark graph 2 gpus."""
     self._setup()
     FLAGS.num_gpus = 2
-    FLAGS.batch_size = 2048 * 2
+    FLAGS.batch_size = 4096 * 2
     FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
     self._run_and_report_benchmark()
 
+  def benchmark_graph_fp16_2_gpu(self):
+    """Benchmark graph fp16 2 gpus."""
+    self._setup()
+    FLAGS.num_gpus = 2
+    FLAGS.dtype = 'fp16'
+    FLAGS.batch_size = 4096 * 2
+    FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_2_gpu')
+    self._run_and_report_benchmark()
+
   def benchmark_graph_4_gpu(self):
     """Benchmark graph 4 gpus."""
     self._setup()
     FLAGS.num_gpus = 4
-    FLAGS.batch_size = 2048 * 4
+    FLAGS.batch_size = 4096 * 4
     FLAGS.model_dir = self._get_model_dir('benchmark_graph_4_gpu')
     self._run_and_report_benchmark()
 
+  def benchmark_graph_fp16_4_gpu(self):
+    """Benchmark graph fp16 4 gpus."""
+    self._setup()
+    FLAGS.num_gpus = 4
+    FLAGS.dtype = 'fp16'
+    FLAGS.batch_size = 4096 * 4
+    FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_4_gpu')
+    self._run_and_report_benchmark()
+
   def benchmark_graph_8_gpu(self):
     """Benchmark graph 8 gpus."""
     self._setup()
     FLAGS.num_gpus = 8
-    FLAGS.batch_size = 2048 * 8
+    FLAGS.batch_size = 4096 * 8
     FLAGS.model_dir = self._get_model_dir('benchmark_graph_8_gpu')
     self._run_and_report_benchmark()
 
+  def benchmark_graph_fp16_8_gpu(self):
+    """Benchmark graph fp16 8 gpus."""
+    self._setup()
+    FLAGS.num_gpus = 8
+    FLAGS.dtype = 'fp16'
+    FLAGS.batch_size = 4096 * 8
+    FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_8_gpu')
+    self._run_and_report_benchmark()
+
   def _run_and_report_benchmark(self):
     start_time_sec = time.time()
     stats = transformer_main.run_transformer(flags.FLAGS)
...
@@ -182,6 +182,11 @@ def get_train_op_and_metrics(loss, params):
   if params["use_tpu"] and params["tpu"] != tpu_util.LOCAL:
     optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
 
+  # Uses automatic mixed precision FP16 training if on GPU.
+  if params["dtype"] == "fp16":
+    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
+        optimizer)
+
   # Calculate and apply gradients using LazyAdamOptimizer.
   global_step = tf.train.get_global_step()
   tvars = tf.trainable_variables()
@@ -232,8 +237,8 @@ def evaluate_and_log_bleu(estimator, bleu_source, bleu_ref, vocab_file):
   uncased_score, cased_score = translate_and_compute_bleu(
       estimator, subtokenizer, bleu_source, bleu_ref)
 
-  tf.logging.info("Bleu score (uncased): %d", uncased_score)
-  tf.logging.info("Bleu score (cased): %d", cased_score)
+  tf.logging.info("Bleu score (uncased): %f", uncased_score)
+  tf.logging.info("Bleu score (cased): %f", cased_score)
 
   return uncased_score, cased_score
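The %d-to-%f fix matters because BLEU scores are fractional and %-style formatting with %d silently truncates them. A quick illustration:

```python
import tensorflow as tf

tf.logging.set_verbosity(tf.logging.INFO)
uncased_score = 27.3
tf.logging.info("Bleu score (uncased): %d", uncased_score)  # logs "27"
tf.logging.info("Bleu score (uncased): %f", uncased_score)  # logs "27.300000"
```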
@@ -392,7 +397,7 @@ def define_transformer_flags():
       intra_op=False,
       synthetic_data=True,
       max_train_steps=False,
-      dtype=False,
+      dtype=True,
       all_reduce_alg=True
   )
 
   flags_core.define_benchmark()
...
@@ -34,7 +34,7 @@ DTYPE_MAP = {
 def get_tf_dtype(flags_obj):
-  if getattr(flags_obj, 'fp16_implementation', None) == 'graph_rewrite':
+  if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
     # If the graph_rewrite is used, we build the graph with fp32, and let the
     # graph rewrite change ops to fp16.
     return tf.float32
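The early return here is the key detail: with --fp16_implementation=graph_rewrite the Python graph is still built in fp32 and the rewrite downcasts ops afterwards, so get_tf_dtype deliberately reports tf.float32. A condensed restatement, with the DTYPE_MAP contents and the fallthrough return assumed since they are not shown in the diff:

```python
import tensorflow as tf

DTYPE_MAP = {"fp16": tf.float16, "fp32": tf.float32}  # assumed mapping

def get_tf_dtype(flags_obj):
  if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
    # Build the graph in fp32 and let the rewrite change ops to fp16.
    return tf.float32
  return DTYPE_MAP[flags_obj.dtype]  # assumed fallthrough
```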
@@ -55,7 +55,8 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
                        tf_gpu_thread_mode=False,
                        datasets_num_private_threads=False,
                        datasets_num_parallel_batches=False,
-                       dynamic_loss_scale=False, fp16_implementation=False):
+                       dynamic_loss_scale=False, fp16_implementation=False,
+                       loss_scale=False):
   """Register flags for specifying performance tuning arguments.
 
   Args:
@@ -76,6 +77,8 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
     dynamic_loss_scale: Allow the "loss_scale" flag to take on the value
       "dynamic". Only valid if `dtype` is True.
     fp16_implementation: Create fp16_implementation flag.
+    loss_scale: Controls the loss scaling, normally for mixed-precision
+      training. Can only be turned on if dtype is also True.
 
   Returns:
     A list of flags for core.py to mark as key flags.
@@ -152,31 +155,33 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
   loss_scale_help_text = loss_scale_help_text.format(
       "This must be an int/float", "")
   loss_scale_validation_msg = "loss_scale should be a positive int/float."
-  flags.DEFINE_string(
-      name="loss_scale", short_name="ls", default=None,
-      help=help_wrap(loss_scale_help_text))
-
-  @flags.validator(flag_name="loss_scale", message=loss_scale_validation_msg)
-  def _check_loss_scale(loss_scale):  # pylint: disable=unused-variable
-    """Validator to check the loss scale flag is valid"""
-    if loss_scale is None:
-      return True  # null case is handled in get_loss_scale()
-
-    if loss_scale == "dynamic" and dynamic_loss_scale:
-      return True
-
-    try:
-      loss_scale = float(loss_scale)
-    except ValueError:
-      return False
-
-    return loss_scale > 0
+  if loss_scale:
+    flags.DEFINE_string(
+        name="loss_scale", short_name="ls", default=None,
+        help=help_wrap(loss_scale_help_text))
+
+    @flags.validator(flag_name="loss_scale",
+                     message=loss_scale_validation_msg)
+    def _check_loss_scale(loss_scale):  # pylint: disable=unused-variable
+      """Validator to check the loss scale flag is valid."""
+      if loss_scale is None:
+        return True  # null case is handled in get_loss_scale()
+
+      if loss_scale == "dynamic" and dynamic_loss_scale:
+        return True
+
+      try:
+        loss_scale = float(loss_scale)
+      except ValueError:
+        return False
+
+      return loss_scale > 0
 
   if fp16_implementation:
     # Currently, this flag is only defined for the estimator resnet model.
     flags.DEFINE_enum(
-        name="fp16_implementation", default='casting',
-        enum_values=('casting', 'graph_rewrite'),
+        name="fp16_implementation", default="casting",
+        enum_values=("casting", "graph_rewrite"),
         help=help_wrap(
             "When --dtype=fp16, how fp16 should be implemented. This has no "
             "impact on correctness. 'casting' will cause manual tf.casts to "
@@ -184,19 +189,19 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
             "tf.train.experimental.enable_mixed_precision_graph_rewrite will "
            "be used to automatically use fp16 without any manual casts."))
 
-    @flags.multi_flags_validator(['fp16_implementation', 'dtype',
-                                  'loss_scale'])
+    @flags.multi_flags_validator(["fp16_implementation", "dtype",
+                                  "loss_scale"])
     def _check_fp16_implementation(flags_dict):
       """Validator to check fp16_implementation flag is valid."""
-      if (flags_dict['fp16_implementation'] == 'graph_rewrite' and
-          flags_dict['dtype'] != 'fp16'):
-        raise flags.ValidationError('--fp16_implementation should not be '
-                                    'specified unless --dtype=fp16')
-      if (flags_dict['fp16_implementation'] != 'graph_rewrite' and
-          flags_dict['loss_scale'] == 'dynamic'):
-        raise flags.ValidationError('--loss_scale=dynamic is only supported '
-                                    'when '
-                                    '--fp16_implementation=graph_rewrite')
+      if (flags_dict["fp16_implementation"] == "graph_rewrite" and
+          flags_dict["dtype"] != "fp16"):
+        raise flags.ValidationError("--fp16_implementation should not be "
+                                    "specified unless --dtype=fp16")
+      if (flags_dict["fp16_implementation"] != "graph_rewrite" and
+          flags_dict["loss_scale"] == "dynamic"):
+        raise flags.ValidationError("--loss_scale=dynamic is only supported "
+                                    "when "
+                                    "--fp16_implementation=graph_rewrite")
       return True
 
   if all_reduce_alg:
...
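For reference, the validator's accept/reject behavior restated as a self-contained sketch (values arrive as strings because loss_scale is a DEFINE_string flag; dynamic_loss_scale mirrors the define_performance argument):

```python
def check_loss_scale(loss_scale, dynamic_loss_scale=False):
  """Mirrors _check_loss_scale: None defers, 'dynamic' is gated, and
  anything else must parse as a positive number."""
  if loss_scale is None:
    return True  # the null case is resolved later by get_loss_scale()
  if loss_scale == "dynamic" and dynamic_loss_scale:
    return True
  try:
    return float(loss_scale) > 0
  except ValueError:
    return False

assert check_loss_scale(None)
assert check_loss_scale("128") and check_loss_scale("0.5")
assert check_loss_scale("dynamic", dynamic_loss_scale=True)
assert not check_loss_scale("dynamic")  # rejected unless dynamic is enabled
assert not check_loss_scale("-8") and not check_loss_scale("bad")
```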
@@ -23,7 +23,7 @@ from official.utils.flags import core as flags_core  # pylint: disable=g-bad-import-order
 
 def define_flags():
   flags_core.define_base(num_gpu=False)
-  flags_core.define_performance(dynamic_loss_scale=True)
+  flags_core.define_performance(dynamic_loss_scale=True, loss_scale=True)
   flags_core.define_image()
   flags_core.define_benchmark()
...
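Note that the transformer passes dtype=True but leaves loss_scale off, so it relies on the graph rewrite's built-in loss scaling, while the models above opt into the explicit --loss_scale flag. A hedged end-to-end sketch of an fp16 transformer run under the new flags (placeholder paths; assumes define_transformer_flags() has registered the flags and absl has parsed argv):

```python
from absl import flags
from official.transformer import transformer_main

FLAGS = flags.FLAGS

FLAGS.dtype = 'fp16'  # triggers enable_mixed_precision_graph_rewrite
FLAGS.param_set = 'base'
FLAGS.batch_size = 4096
FLAGS.data_dir = '/path/to/wmt_data'      # placeholder
FLAGS.model_dir = '/path/to/model_dir'    # placeholder
FLAGS.vocab_file = '/path/to/vocab.file'  # placeholder

stats = transformer_main.run_transformer(FLAGS)
```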