Commit 8a3fecc0 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Merge pull request #7419 from vinhngx:amp_resnet50

PiperOrigin-RevId: 265939281
parents 82ba6174 a35e09d2
......@@ -78,7 +78,24 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
FLAGS.datasets_num_private_threads = 14
FLAGS.use_tensor_lr = True
self._run_and_report_benchmark()
def benchmark_8_gpu_amp(self):
"""Test Keras model with eager, dist_strat and 8 GPUs with automatic mixed precision."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 128 * 8
FLAGS.train_epochs = 90
FLAGS.epochs_between_evals = 10
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
FLAGS.dtype = 'fp16'
FLAGS.enable_eager = True
FLAGS.fp16_implementation = 'graph_rewrite'
# Add some thread tunings to improve performance.
FLAGS.datasets_num_private_threads = 14
FLAGS.use_tensor_lr = True
self._run_and_report_benchmark()
def benchmark_8_gpu_fp16(self):
"""Test Keras model with eager, dist_strat, 8 GPUs, and fp16."""
self._setup()
......@@ -303,6 +320,18 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_1_gpu_amp(self):
"""Test Keras model with 1 GPU with automatic mixed precision."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = True
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.distribution_strategy = 'default'
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
FLAGS.batch_size = 256
self._run_and_report_benchmark()
def benchmark_xla_1_gpu(self):
"""Test Keras model with XLA and 1 GPU."""
......@@ -316,6 +345,20 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_xla_1_gpu_amp(self):
"""Test Keras model with XLA and 1 GPU with automatic mixed precision."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = True
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.enable_xla = True
FLAGS.distribution_strategy = 'default'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
FLAGS.batch_size = 256
self._run_and_report_benchmark()
def benchmark_1_gpu_fp16(self):
"""Test Keras model with 1 GPU and fp16."""
self._setup()
......@@ -458,6 +501,19 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.batch_size = 128 * 8 # 8 GPUs
self._run_and_report_benchmark()
def benchmark_8_gpu_amp(self):
"""Test Keras model with 8 GPUs with automatic mixed precision."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.enable_eager = True
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.distribution_strategy = 'default'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
FLAGS.batch_size = 256 * 8 # 8 GPUs
self._run_and_report_benchmark()
def benchmark_8_gpu_tweaked(self):
"""Test Keras model with manual config tuning and 8 GPUs."""
self._setup()
......@@ -483,6 +539,20 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.batch_size = 128 * 8 # 8 GPUs
self._run_and_report_benchmark()
def benchmark_xla_8_gpu_amp(self):
"""Test Keras model with XLA and 8 GPUs with automatic mixed precision."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.enable_eager = True
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.enable_xla = True
FLAGS.distribution_strategy = 'default'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp')
FLAGS.batch_size = 256 * 8 # 8 GPUs
self._run_and_report_benchmark()
def benchmark_xla_8_gpu_tweaked(self):
"""Test Keras model with manual config tuning, 8 GPUs, and XLA."""
self._setup()
......
......@@ -72,7 +72,8 @@ def define_transformer_flags():
loss_scale=True,
all_reduce_alg=True,
enable_xla=True,
force_v2_in_keras_compile=True
force_v2_in_keras_compile=True,
fp16_implementation=True
)
# Additional performance flags
......@@ -85,7 +86,7 @@ def define_transformer_flags():
'convolutions and batch normalizations, and this flag allows to '
'disable it.'
)
flags_core.define_benchmark()
flags_core.define_device(tpu=True)
......
......@@ -345,6 +345,30 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
bleu_min=28,
bleu_max=29.2)
def benchmark_8_gpu_fp16_amp(self):
"""Benchmark 8 gpu with dynamic batch and fp16 with automatic mixed precision.
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12
FLAGS.steps_between_evals = 20000
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_amp')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=28,
bleu_max=29)
def benchmark_8_gpu_static_batch_fp16(self):
"""Benchmark 8 gpu with static batch and fp16.
......
......@@ -419,6 +419,18 @@ class TransformerTask(object):
params["optimizer_adam_beta1"],
params["optimizer_adam_beta2"],
epsilon=params["optimizer_adam_epsilon"])
if params["dtype"] == tf.float16:
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
opt, loss_scale=flags_core.get_loss_scale(self.flags_obj,
default_for_fp16="dynamic"))
if self.flags_obj.fp16_implementation == "graph_rewrite":
# Note: when flags_obj.fp16_implementation == "graph_rewrite",
# dtype as determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.keras.mixed_precision and tf.train.experimental.enable_mixed_precision_graph_rewrite
# do not double up.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
return opt
......
......@@ -190,7 +190,7 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
return loss_scale > 0
if fp16_implementation:
# Currently, this flag is only defined for the estimator resnet model.
# Currently, this flag is only defined for the estimator resnet and transformer models.
flags.DEFINE_enum(
name="fp16_implementation", default="casting",
enum_values=("casting', 'graph_rewrite"),
......
......@@ -257,6 +257,7 @@ def define_keras_flags(dynamic_loss_scale=True):
datasets_num_private_threads=True,
dynamic_loss_scale=dynamic_loss_scale,
loss_scale=True,
fp16_implementation=True,
tf_data_experimental_slack=True,
enable_xla=True,
force_v2_in_keras_compile=True)
......
......@@ -33,7 +33,6 @@ from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_model
from official.benchmark.models import trivial_model
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]
......@@ -186,6 +185,13 @@ def run(flags_obj):
optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
default_for_fp16=128))
if flags_obj.fp16_implementation == "graph_rewrite":
# Note: when flags_obj.fp16_implementation == "graph_rewrite",
# dtype as determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.keras.mixed_precision and tf.train.experimental.enable_mixed_precision_graph_rewrite
# do not double up.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
# TODO(hongkuny): Remove trivial model usage and move it to benchmark.
if flags_obj.use_trivial_model:
model = trivial_model.trivial_model(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment