Commit 8fe2729e authored by Kaixi Hou

Enable graph rewrite for ResNet50 CTL

Add a check that use_tf_function is true when
--fp16_implementation=graph_rewrite is set, and add benchmarks for the
CTL with AMP.
parent a009f4fb
@@ -139,6 +139,24 @@ class Resnet50CtlAccuracy(CtlBenchmark):
    FLAGS.datasets_num_private_threads = 14
    self._run_and_report_benchmark()

  def benchmark_8_gpu_amp(self):
    """Test Keras model with eager, dist_strat and 8 GPUs with automatic mixed
    precision.
    """
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.data_dir = self.data_dir
    FLAGS.batch_size = 128 * 8
    FLAGS.train_epochs = 90
    FLAGS.epochs_between_evals = 10
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    FLAGS.use_tf_function = True
    # Add some thread tunings to improve performance.
    FLAGS.datasets_num_private_threads = 14
    self._run_and_report_benchmark()

  def _run_and_report_benchmark(self):
    start_time_sec = time.time()
    stats = ctl_imagenet_main.run(flags.FLAGS)
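Note: the only settings that distinguish this AMP benchmark from the fp32
accuracy benchmark above it are dtype, fp16_implementation, and
use_tf_function. A minimal sketch of driving the trainer with the same
flags outside the benchmark harness; the driver function and the module
import path are illustrative, not part of this commit:

# Hypothetical standalone driver mirroring benchmark_8_gpu_amp above.
# Flag names are taken from the diff; the import path is assumed.
from absl import flags
import ctl_imagenet_main

FLAGS = flags.FLAGS

def run_resnet50_amp():
  FLAGS.num_gpus = 8
  FLAGS.batch_size = 128 * 8
  FLAGS.dtype = 'fp16'                         # fp16 compute
  FLAGS.fp16_implementation = 'graph_rewrite'  # AMP via graph rewrite
  FLAGS.use_tf_function = True                 # required by the new check
  return ctl_imagenet_main.run(FLAGS)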
@@ -206,6 +224,33 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
    FLAGS.batch_size = 128
    self._run_and_report_benchmark()

  def benchmark_1_gpu_amp(self):
    """Test Keras model with 1 GPU with automatic mixed precision."""
    self._setup()
    FLAGS.num_gpus = 1
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
    FLAGS.batch_size = 128
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    FLAGS.use_tf_function = True
    self._run_and_report_benchmark()

  def benchmark_xla_1_gpu_amp(self):
    """Test Keras model with XLA and 1 GPU with automatic mixed precision."""
    self._setup()
    FLAGS.num_gpus = 1
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
    FLAGS.batch_size = 128
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    FLAGS.use_tf_function = True
    FLAGS.enable_xla = True
    self._run_and_report_benchmark()

  def benchmark_1_gpu_eager(self):
    """Test Keras model with 1 GPU in pure eager mode."""
    self._setup()
@@ -228,6 +273,33 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
    FLAGS.batch_size = 128 * 8  # 8 GPUs
    self._run_and_report_benchmark()

  def benchmark_8_gpu_amp(self):
    """Test Keras model with 8 GPUs with automatic mixed precision."""
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
    FLAGS.batch_size = 128 * 8  # 8 GPUs
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    FLAGS.use_tf_function = True
    self._run_and_report_benchmark()

  def benchmark_xla_8_gpu_amp(self):
    """Test Keras model with XLA and 8 GPUs with automatic mixed precision."""
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp')
    FLAGS.batch_size = 128 * 8  # 8 GPUs
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    FLAGS.use_tf_function = True
    FLAGS.enable_xla = True
    self._run_and_report_benchmark()

  def fill_report_object(self, stats):
    super(Resnet50CtlBenchmarkBase, self).fill_report_object(
        stats,
@@ -171,6 +171,14 @@ def run(flags_obj):
        learning_rate=common.BASE_LEARNING_RATE, momentum=0.9,
        nesterov=True)

    if flags_obj.fp16_implementation == "graph_rewrite":
      if not flags_obj.use_tf_function:
        raise ValueError("--fp16_implementation=graph_rewrite requires "
                         "use_tf_function to be true")
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)
      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)

    training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'training_accuracy', dtype=tf.float32)
    test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
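For context, enable_mixed_precision_graph_rewrite wraps an existing
optimizer and turns on a Grappler pass that converts eligible ops to
float16 when a graph is built, which is why this commit insists on
use_tf_function being true: in pure eager mode there is no graph for the
rewrite to act on. A minimal sketch of the wrapping pattern, using a toy
SGD optimizer:

import tensorflow as tf

# Sketch of the AMP graph-rewrite pattern used in run() above.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    optimizer)

The returned optimizer only applies loss scaling inside its own
minimize()/get_gradients() path; a custom training loop that calls
tape.gradient() directly must scale the loss and unscale the gradients
itself, which is exactly what the next hunk adds.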
@@ -203,7 +211,19 @@ def run(flags_obj):
          loss += (l2_loss / num_replicas)
        else:
          loss += (tf.reduce_sum(model.losses) / num_replicas)

        # Scale the loss so small fp16 gradients survive the backward pass.
        if flags_obj.fp16_implementation == "graph_rewrite":
          loss = loss * tf.cast(loss_scale, loss.dtype)

      grads = tape.gradient(loss, trainable_variables)

      # Unscale the gradients to undo the loss scaling above.
      if flags_obj.fp16_implementation == "graph_rewrite":
        loss_scale_reciprocal = 1. / loss_scale
        grads = [g * tf.cast(loss_scale_reciprocal, g.dtype) if g is not None
                 else None for g in grads]

      optimizer.apply_gradients(zip(grads, trainable_variables))

      training_accuracy.update_state(labels, logits)
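A self-contained sketch of this manual static loss-scaling pattern,
assuming a toy Dense model and the static scale of 128 that
get_loss_scale defaults to for fp16 in this commit:

import tensorflow as tf

loss_scale = 128.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
optimizer = tf.keras.optimizers.SGD(0.1)

@tf.function
def train_step(images, labels):
  with tf.GradientTape() as tape:
    logits = model(images, training=True)
    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits, from_logits=True))
    # Scale the loss so small fp16 gradients do not flush to zero.
    scaled_loss = loss * tf.cast(loss_scale, loss.dtype)
  grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Unscale the gradients to recover their true magnitudes.
  grads = [g * tf.cast(1. / loss_scale, g.dtype) if g is not None
           else None for g in grads]
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss

Multiplying the loss by a constant multiplies every gradient by the same
constant, so scaling up before tape.gradient() and dividing the result by
the same factor is mathematically a no-op while keeping the intermediate
fp16 gradients above the underflow threshold.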
@@ -296,6 +316,5 @@ if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  common.define_keras_flags()
  ctl_common.define_ctl_flags()
  flags.adopt_module_key_flags(keras_common)
  flags.adopt_module_key_flags(ctl_common)
  absl_app.run(main)