Commit 64f87cd2 authored by A. Unique TensorFlower

Merge pull request #7535 from houtoms:ctl_supports_amp

PiperOrigin-RevId: 267685527
parents a629af4c 78047d54
@@ -139,6 +139,21 @@ class Resnet50CtlAccuracy(CtlBenchmark):
    FLAGS.datasets_num_private_threads = 14
    self._run_and_report_benchmark()

  def benchmark_8_gpu_amp(self):
    """Test Keras model with eager execution, 8 GPUs, and automatic mixed precision."""
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.data_dir = self.data_dir
    FLAGS.batch_size = 128 * 8
    FLAGS.train_epochs = 90
    FLAGS.epochs_between_evals = 10
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    # Add some thread tunings to improve performance.
    FLAGS.datasets_num_private_threads = 14
    self._run_and_report_benchmark()

  def _run_and_report_benchmark(self):
    start_time_sec = time.time()
    stats = ctl_imagenet_main.run(flags.FLAGS)
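Note: every new *_amp benchmark method above follows the same pattern: mutate absl flags on FLAGS, then delegate to _run_and_report_benchmark, which times ctl_imagenet_main.run. A minimal sketch of that pattern, with hypothetical wiring (ctl_imagenet_main and the flag definitions come from this repo's imports; the two AMP flags are the ones set throughout this diff):

  import time
  from absl import flags

  FLAGS = flags.FLAGS

  def run_amp_benchmark():
    # These two flags are what turn on automatic mixed precision in
    # these benchmarks; everything else is ordinary benchmark plumbing.
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    start_time_sec = time.time()
    stats = ctl_imagenet_main.run(FLAGS)
    wall_time_sec = time.time() - start_time_sec
    return stats, wall_time_sec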
@@ -206,6 +221,31 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
    FLAGS.batch_size = 128
    self._run_and_report_benchmark()

  def benchmark_1_gpu_amp(self):
    """Test Keras model with 1 GPU and automatic mixed precision."""
    self._setup()
    FLAGS.num_gpus = 1
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
    FLAGS.batch_size = 128
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    self._run_and_report_benchmark()

  def benchmark_xla_1_gpu_amp(self):
    """Test Keras model with XLA, 1 GPU, and automatic mixed precision."""
    self._setup()
    FLAGS.num_gpus = 1
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
    FLAGS.batch_size = 128
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    FLAGS.enable_xla = True
    self._run_and_report_benchmark()

  def benchmark_1_gpu_eager(self):
    """Test Keras model with 1 GPU in pure eager mode."""
    self._setup()
@@ -228,6 +268,31 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
    FLAGS.batch_size = 128 * 8  # 8 GPUs
    self._run_and_report_benchmark()

  def benchmark_8_gpu_amp(self):
    """Test Keras model with 8 GPUs and automatic mixed precision."""
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
    FLAGS.batch_size = 128 * 8  # 8 GPUs
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    self._run_and_report_benchmark()

  def benchmark_xla_8_gpu_amp(self):
    """Test Keras model with XLA, 8 GPUs, and automatic mixed precision."""
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp')
    FLAGS.batch_size = 128 * 8  # 8 GPUs
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    FLAGS.enable_xla = True
    self._run_and_report_benchmark()

  def fill_report_object(self, stats):
    super(Resnet50CtlBenchmarkBase, self).fill_report_object(
        stats,
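Note: the benchmark_xla_* variants differ from their plain AMP counterparts only in setting FLAGS.enable_xla = True. In standalone TF 2.x code, a roughly equivalent switch (an assumption about the flag's effect, not something this diff states) is enabling JIT compilation globally:

  import tensorflow as tf

  # Let XLA compile eligible clusters of ops; commonly combined with the
  # mixed-precision graph rewrite, as in the benchmark_xla_*_amp methods.
  tf.config.optimizer.set_jit(True)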
@@ -171,6 +171,14 @@ def run(flags_obj):
      learning_rate=common.BASE_LEARNING_RATE, momentum=0.9,
      nesterov=True)

  if flags_obj.fp16_implementation == "graph_rewrite":
    if not flags_obj.use_tf_function:
      raise ValueError("--fp16_implementation=graph_rewrite requires "
                       "--use_tf_function to be true")
    loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer, loss_scale)

  training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      'training_accuracy', dtype=tf.float32)
  test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
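Note: the hunk above is where AMP is actually enabled: when --fp16_implementation=graph_rewrite, the optimizer is wrapped by TensorFlow's mixed precision graph rewrite, which casts eligible ops to float16 and attaches loss scaling (defaulting to 128 here). A self-contained sketch of the same call, assuming a TF 2.0-era build where this experimental API exists:

  import tensorflow as tf

  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9,
                                      nesterov=True)
  # Wraps the optimizer in a loss-scaling optimizer and turns on the fp16
  # graph rewrite; loss_scale may also be a fixed value such as 128.
  optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
      optimizer, loss_scale='dynamic')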
@@ -203,7 +211,17 @@ def run(flags_obj):
          loss += (l2_loss / num_replicas)
        else:
          loss += (tf.reduce_sum(model.losses) / num_replicas)

        # Scale the loss while the tape is still recording so that fp16
        # gradients are computed at the scaled magnitude and do not underflow.
        if flags_obj.dtype == "fp16":
          loss = optimizer.get_scaled_loss(loss)

      grads = tape.gradient(loss, trainable_variables)

      # Unscale the grads back to their true magnitude before applying them.
      if flags_obj.dtype == "fp16":
        grads = optimizer.get_unscaled_gradients(grads)

      optimizer.apply_gradients(zip(grads, trainable_variables))
      training_accuracy.update_state(labels, logits)
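Note: this hunk is the custom-training-loop half of loss scaling: the loss is scaled inside the tape (so the multiply is recorded and gradients flow through it), and the gradients are unscaled before apply_gradients. A minimal sketch of the same pattern with an explicit loss-scale optimizer, assuming the TF 2.0-era tf.keras.mixed_precision.experimental API (model, images, and labels are placeholders):

  import tensorflow as tf

  opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
      tf.keras.optimizers.SGD(0.1), loss_scale='dynamic')

  def train_step(model, images, labels):
    with tf.GradientTape() as tape:
      logits = model(images, training=True)
      loss = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True))
      scaled_loss = opt.get_scaled_loss(loss)  # scale inside the tape
    scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
    grads = opt.get_unscaled_gradients(scaled_grads)  # back to true scale
    opt.apply_gradients(zip(grads, model.trainable_variables))
    return loss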
@@ -296,6 +314,5 @@ if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  common.define_keras_flags()
  ctl_common.define_ctl_flags()
  flags.adopt_module_key_flags(keras_common)
  flags.adopt_module_key_flags(ctl_common)
  absl_app.run(main)