Commit fcb8590b authored by Kaixi Hou
Browse files

minor changes

parent 8fe2729e
...@@ -140,9 +140,7 @@ class Resnet50CtlAccuracy(CtlBenchmark): ...@@ -140,9 +140,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_8_gpu_amp(self): def benchmark_8_gpu_amp(self):
"""Test Keras model with eager, dist_strat and 8 GPUs with automatic mixed """Test Keras model with eager, 8 GPUs with automatic mixed precision."""
precision.
"""
self._setup() self._setup()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.data_dir FLAGS.data_dir = self.data_dir
...@@ -152,7 +150,6 @@ class Resnet50CtlAccuracy(CtlBenchmark): ...@@ -152,7 +150,6 @@ class Resnet50CtlAccuracy(CtlBenchmark):
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp') FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite' FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.use_tf_function = True
# Add some thread tunings to improve performance. # Add some thread tunings to improve performance.
FLAGS.datasets_num_private_threads = 14 FLAGS.datasets_num_private_threads = 14
self._run_and_report_benchmark() self._run_and_report_benchmark()
...@@ -234,7 +231,6 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -234,7 +231,6 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.batch_size = 128 FLAGS.batch_size = 128
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite' FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.use_tf_function = True
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_xla_1_gpu_amp(self): def benchmark_xla_1_gpu_amp(self):
...@@ -247,7 +243,6 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -247,7 +243,6 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.batch_size = 128 FLAGS.batch_size = 128
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite' FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.use_tf_function = True
FLAGS.enable_xla = True FLAGS.enable_xla = True
self._run_and_report_benchmark() self._run_and_report_benchmark()
...@@ -283,7 +278,6 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -283,7 +278,6 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.batch_size = 128 * 8 # 8 GPUs FLAGS.batch_size = 128 * 8 # 8 GPUs
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite' FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.use_tf_function = True
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_xla_8_gpu_amp(self): def benchmark_xla_8_gpu_amp(self):
...@@ -296,7 +290,6 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -296,7 +290,6 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.batch_size = 128 * 8 # 8 GPUs FLAGS.batch_size = 128 * 8 # 8 GPUs
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite' FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.use_tf_function = True
FLAGS.enable_xla = True FLAGS.enable_xla = True
self._run_and_report_benchmark() self._run_and_report_benchmark()
......
...@@ -174,10 +174,10 @@ def run(flags_obj): ...@@ -174,10 +174,10 @@ def run(flags_obj):
if flags_obj.fp16_implementation == "graph_rewrite": if flags_obj.fp16_implementation == "graph_rewrite":
if not flags_obj.use_tf_function: if not flags_obj.use_tf_function:
raise ValueError("--fp16_implementation=graph_rewrite requires " raise ValueError("--fp16_implementation=graph_rewrite requires "
"use_tf_function to be true") "--use_tf_function to be true")
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer)
loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128) loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer, loss_scale)
training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
'training_accuracy', dtype=tf.float32) 'training_accuracy', dtype=tf.float32)
...@@ -213,16 +213,14 @@ def run(flags_obj): ...@@ -213,16 +213,14 @@ def run(flags_obj):
loss += (tf.reduce_sum(model.losses) / num_replicas) loss += (tf.reduce_sum(model.losses) / num_replicas)
# Scale the loss # Scale the loss
if flags_obj.fp16_implementation == "graph_rewrite": if flags_obj.dtype == "fp16":
loss = loss * tf.cast(loss_scale, loss.dtype) loss = optimizer.get_scaled_loss(loss)
grads = tape.gradient(loss, trainable_variables) grads = tape.gradient(loss, trainable_variables)
# Unscale the grads # Unscale the grads
if flags_obj.fp16_implementation == "graph_rewrite": if flags_obj.dtype == "fp16":
loss_scale_reciprocal = 1. / loss_scale grads = optimizer.get_unscaled_gradients(grads)
grads = [g * tf.cast(loss_scale_reciprocal, g.dtype) if g is not None
else None for g in grads]
optimizer.apply_gradients(zip(grads, trainable_variables)) optimizer.apply_gradients(zip(grads, trainable_variables))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment