Commit 8fe2729e authored by Kaixi Hou

Enable graph rewrite for ResNet-50 CTL

Add a check that use_tf_function is true.

Add benchmarks for the CTL with AMP.
parent a009f4fb
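For reference, the API this change turns on wraps the optimizer so that a graph optimization pass runs eligible ops in float16; a minimal sketch, assuming a plain SGD optimizer like the one in ctl_imagenet_main.py:

```python
import tensorflow as tf

optimizer = tf.keras.optimizers.SGD(
    learning_rate=0.1, momentum=0.9, nesterov=True)

# Wrap the optimizer to enable the AMP graph rewrite. The rewrite is a
# graph-level pass, so it only takes effect inside graphs built with
# tf.function -- which is why this commit rejects
# --fp16_implementation=graph_rewrite unless use_tf_function is true.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    optimizer)
```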
@@ -139,6 +139,24 @@ class Resnet50CtlAccuracy(CtlBenchmark):
FLAGS.datasets_num_private_threads = 14
self._run_and_report_benchmark()
def benchmark_8_gpu_amp(self):
"""Test Keras model with eager, dist_strat and 8 GPUs with automatic mixed
precision.
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 128 * 8
FLAGS.train_epochs = 90
FLAGS.epochs_between_evals = 10
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.use_tf_function = True
# Add some thread tunings to improve performance.
FLAGS.datasets_num_private_threads = 14
self._run_and_report_benchmark()
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = ctl_imagenet_main.run(flags.FLAGS)
@@ -206,6 +224,33 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_1_gpu_amp(self):
"""Test Keras model with 1 GPU with automatic mixed precision."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.distribution_strategy = 'default'
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
FLAGS.batch_size = 128
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.use_tf_function = True
self._run_and_report_benchmark()
def benchmark_xla_1_gpu_amp(self):
"""Test Keras model with XLA and 1 GPU with automatic mixed precision."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.distribution_strategy = 'default'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
FLAGS.batch_size = 128
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.use_tf_function = True
FLAGS.enable_xla = True
self._run_and_report_benchmark()
def benchmark_1_gpu_eager(self):
"""Test Keras model with 1 GPU in pure eager mode."""
self._setup()
@@ -228,6 +273,33 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.batch_size = 128 * 8  # 8 GPUs
self._run_and_report_benchmark()
def benchmark_8_gpu_amp(self):
"""Test Keras model with 8 GPUs with automatic mixed precision."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.distribution_strategy = 'default'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
FLAGS.batch_size = 128 * 8 # 8 GPUs
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.use_tf_function = True
self._run_and_report_benchmark()
def benchmark_xla_8_gpu_amp(self):
"""Test Keras model with XLA and 8 GPUs with automatic mixed precision."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.distribution_strategy = 'default'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp')
FLAGS.batch_size = 128 * 8 # 8 GPUs
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.use_tf_function = True
FLAGS.enable_xla = True
self._run_and_report_benchmark()
def fill_report_object(self, stats):
super(Resnet50CtlBenchmarkBase, self).fill_report_object(
stats,
...
@@ -171,6 +171,14 @@ def run(flags_obj):
learning_rate=common.BASE_LEARNING_RATE, momentum=0.9,
nesterov=True)
if flags_obj.fp16_implementation == "graph_rewrite":
if not flags_obj.use_tf_function:
raise ValueError("--fp16_implementation=graph_rewrite requires "
"use_tf_function to be true")
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer)
loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
'training_accuracy', dtype=tf.float32)
test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
@@ -203,7 +211,19 @@ def run(flags_obj):
loss += (l2_loss / num_replicas)
else:
loss += (tf.reduce_sum(model.losses) / num_replicas)
# Scale the loss so small fp16 gradients do not underflow to zero
if flags_obj.fp16_implementation == "graph_rewrite":
loss = loss * tf.cast(loss_scale, loss.dtype)
grads = tape.gradient(loss, trainable_variables)
# Unscale the gradients to restore their true magnitudes
if flags_obj.fp16_implementation == "graph_rewrite":
loss_scale_reciprocal = 1. / loss_scale
grads = [g * tf.cast(loss_scale_reciprocal, g.dtype) if g is not None
else None for g in grads]
optimizer.apply_gradients(zip(grads, trainable_variables))
training_accuracy.update_state(labels, logits)
@@ -296,6 +316,5 @@ if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
common.define_keras_flags()
ctl_common.define_ctl_flags()
flags.adopt_module_key_flags(keras_common)
flags.adopt_module_key_flags(ctl_common)
absl_app.run(main)
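Because this custom training loop computes gradients with tape.gradient() and applies them with apply_gradients() directly, rather than going through optimizer.minimize(), the wrapper's built-in loss scaling is not applied; hence the manual scale/unscale added above, using the static loss scale from flags_core.get_loss_scale (default 128 for fp16). A self-contained sketch of that pattern, with a hypothetical model and loss:

```python
import tensorflow as tf

loss_scale = 128.  # static scale; mirrors default_for_fp16=128 above

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])  # hypothetical
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(images, labels):
  with tf.GradientTape() as tape:
    logits = model(images, training=True)
    loss = loss_fn(labels, logits)
    # Scale the loss so small fp16 gradients do not underflow to zero.
    scaled_loss = loss * tf.cast(loss_scale, loss.dtype)
  grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Unscale the gradients to recover their true magnitudes; gradients
  # can be None for variables the loss does not depend on.
  loss_scale_reciprocal = 1. / loss_scale
  grads = [g * tf.cast(loss_scale_reciprocal, g.dtype) if g is not None
           else None for g in grads]
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss
```

Note this is static loss scaling: the scale stays fixed at 128 rather than being adjusted dynamically, which matches the manual path in the diff but not the dynamic scaling the graph rewrite would apply under optimizer.minimize().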