Commit 8a3fecc0 authored by A. Unique TensorFlower

Merge pull request #7419 from vinhngx:amp_resnet50

PiperOrigin-RevId: 265939281
parents 82ba6174 a35e09d2

@@ -79,6 +79,23 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
    FLAGS.use_tensor_lr = True
    self._run_and_report_benchmark()

  def benchmark_8_gpu_amp(self):
    """Test Keras model with eager, dist_strat and 8 GPUs with automatic mixed precision."""
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.data_dir = self.data_dir
    FLAGS.batch_size = 128 * 8
    FLAGS.train_epochs = 90
    FLAGS.epochs_between_evals = 10
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
    FLAGS.dtype = 'fp16'
    FLAGS.enable_eager = True
    FLAGS.fp16_implementation = 'graph_rewrite'
    # Add some thread tunings to improve performance.
    FLAGS.datasets_num_private_threads = 14
    FLAGS.use_tensor_lr = True
    self._run_and_report_benchmark()

  def benchmark_8_gpu_fp16(self):
    """Test Keras model with eager, dist_strat, 8 GPUs, and fp16."""
    self._setup()

@@ -303,6 +320,18 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.batch_size = 128
    self._run_and_report_benchmark()

  def benchmark_1_gpu_amp(self):
    """Test Keras model with 1 GPU with automatic mixed precision."""
    self._setup()
    FLAGS.num_gpus = 1
    FLAGS.enable_eager = True
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
    FLAGS.batch_size = 256
    self._run_and_report_benchmark()

  def benchmark_xla_1_gpu(self):
    """Test Keras model with XLA and 1 GPU."""

@@ -316,6 +345,20 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.batch_size = 128
    self._run_and_report_benchmark()

  def benchmark_xla_1_gpu_amp(self):
    """Test Keras model with XLA and 1 GPU with automatic mixed precision."""
    self._setup()
    FLAGS.num_gpus = 1
    FLAGS.enable_eager = True
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    FLAGS.enable_xla = True
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
    FLAGS.batch_size = 256
    self._run_and_report_benchmark()

  def benchmark_1_gpu_fp16(self):
    """Test Keras model with 1 GPU and fp16."""
    self._setup()

@@ -458,6 +501,19 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.batch_size = 128 * 8  # 8 GPUs
    self._run_and_report_benchmark()

  def benchmark_8_gpu_amp(self):
    """Test Keras model with 8 GPUs with automatic mixed precision."""
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.enable_eager = True
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
    FLAGS.batch_size = 256 * 8  # 8 GPUs
    self._run_and_report_benchmark()

  def benchmark_8_gpu_tweaked(self):
    """Test Keras model with manual config tuning and 8 GPUs."""
    self._setup()

@@ -483,6 +539,20 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.batch_size = 128 * 8  # 8 GPUs
    self._run_and_report_benchmark()

  def benchmark_xla_8_gpu_amp(self):
    """Test Keras model with XLA and 8 GPUs with automatic mixed precision."""
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.enable_eager = True
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    FLAGS.enable_xla = True
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp')
    FLAGS.batch_size = 256 * 8  # 8 GPUs
    self._run_and_report_benchmark()

  def benchmark_xla_8_gpu_tweaked(self):
    """Test Keras model with manual config tuning, 8 GPUs, and XLA."""
    self._setup()
......

@@ -72,7 +72,8 @@ def define_transformer_flags():
      loss_scale=True,
      all_reduce_alg=True,
      enable_xla=True,
      force_v2_in_keras_compile=True,
      fp16_implementation=True
  )

  # Additional performance flags
......

@@ -345,6 +345,30 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
                                   bleu_min=28,
                                   bleu_max=29.2)

  def benchmark_8_gpu_fp16_amp(self):
    """Benchmark 8 gpu with dynamic batch and fp16 with automatic mixed precision.

    Should converge to 28.4 BLEU (uncased). This has not been verified yet.
    """
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    FLAGS.data_dir = self.train_data_dir
    FLAGS.vocab_file = self.vocab_file
    # Sets values directly to avoid validation check.
    FLAGS['bleu_source'].value = self.bleu_source
    FLAGS['bleu_ref'].value = self.bleu_ref
    FLAGS.param_set = 'big'
    FLAGS.batch_size = 3072 * 8
    FLAGS.train_steps = 20000 * 12
    FLAGS.steps_between_evals = 20000
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_amp')
    self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
                                   log_steps=FLAGS.log_steps,
                                   bleu_min=28,
                                   bleu_max=29)

  def benchmark_8_gpu_static_batch_fp16(self):
    """Benchmark 8 gpu with static batch and fp16.
......
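
A quick reading of the numbers in benchmark_8_gpu_fp16_amp above: the global batch size is 3072 * 8 = 24,576 (per-replica 3072, which the dynamic-batch Transformer counts in tokens), and training runs 20000 * 12 = 240,000 steps with BLEU evaluated every 20,000 steps, i.e. twelve evaluation points over the run.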

@@ -419,6 +419,18 @@ class TransformerTask(object):
        params["optimizer_adam_beta1"],
        params["optimizer_adam_beta2"],
        epsilon=params["optimizer_adam_epsilon"])
if params["dtype"] == tf.float16:
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
opt, loss_scale=flags_core.get_loss_scale(self.flags_obj,
default_for_fp16="dynamic"))
if self.flags_obj.fp16_implementation == "graph_rewrite":
# Note: when flags_obj.fp16_implementation == "graph_rewrite",
# dtype as determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.keras.mixed_precision and tf.train.experimental.enable_mixed_precision_graph_rewrite
# do not double up.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
return opt return opt
......
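
For readers comparing the two fp16 paths handled in the TransformerTask optimizer setup above, here is a minimal stand-alone sketch of the distinction, assuming a TensorFlow build of the same era that still exposes the experimental APIs used in this change (tf.keras.mixed_precision.experimental.LossScaleOptimizer and tf.train.experimental.enable_mixed_precision_graph_rewrite); the use_graph_rewrite variable merely stands in for the --fp16_implementation flag and is not part of this diff:

import tensorflow as tf

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)

use_graph_rewrite = True  # stands in for --fp16_implementation=graph_rewrite

if use_graph_rewrite:
  # AMP graph-rewrite path: variables stay float32, and the rewrite both
  # inserts float16 casts and applies (dynamic) loss scaling, so no explicit
  # LossScaleOptimizer wrapper is needed.
  opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
else:
  # 'casting' path: dtype resolves to float16, so the optimizer is wrapped
  # explicitly with dynamic loss scaling.
  opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
      opt, loss_scale='dynamic')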

@@ -190,7 +190,7 @@ def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
      return loss_scale > 0

  if fp16_implementation:
    # Currently, this flag is only defined for the estimator resnet and
    # transformer models.
    flags.DEFINE_enum(
        name="fp16_implementation", default="casting",
        enum_values=("casting", "graph_rewrite"),
......
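
As the comments added elsewhere in this change point out, selecting 'graph_rewrite' makes flags_core.get_tf_dtype(flags_obj) report float32, so the Keras-side casting and loss-scaling path is skipped and only the graph rewrite performs mixed precision. A small illustrative sketch of that contract (flags_core, flags_obj and get_tf_dtype are names taken from this diff; the snippet itself is not code from the repository and assumes the flag modules are already imported and the flags parsed):

# Illustrative only: relies on flags_core / flags_obj from official.utils.flags.
dtype = flags_core.get_tf_dtype(flags_obj)
if flags_obj.fp16_implementation == 'graph_rewrite':
  # The reported dtype stays float32 in this mode; the float16 conversion and
  # loss scaling happen entirely inside the mixed-precision graph rewrite.
  assert dtype == tf.float32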

@@ -257,6 +257,7 @@ def define_keras_flags(dynamic_loss_scale=True):
      datasets_num_private_threads=True,
      dynamic_loss_scale=dynamic_loss_scale,
      loss_scale=True,
      fp16_implementation=True,
      tf_data_experimental_slack=True,
      enable_xla=True,
      force_v2_in_keras_compile=True)
......

@@ -33,7 +33,6 @@ from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_model
from official.benchmark.models import trivial_model

LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]
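
For context on the LR_SCHEDULE constant kept by this hunk, here is a hedged illustration of how such (multiplier, epoch-to-start) tuples are typically consumed; the helper below is hypothetical and not part of this file, and the real schedule logic (including warmup and batch-size scaling) lives elsewhere in the Keras ResNet code:

def multiplier_for_epoch(epoch, schedule=LR_SCHEDULE):
  """Returns the last multiplier whose start epoch has been reached."""
  multiplier = 1.0  # assumed default before the first boundary
  for m, start_epoch in schedule:
    if epoch >= start_epoch:
      multiplier = m
  return multiplier

# e.g. multiplier_for_epoch(35) == 0.1 and multiplier_for_epoch(85) == 0.001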

@@ -186,6 +185,13 @@ def run(flags_obj):
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
                                                          default_for_fp16=128))
    if flags_obj.fp16_implementation == "graph_rewrite":
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32',
      # which ensures tf.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(
......