Commit fcb8590b authored by Kaixi Hou's avatar Kaixi Hou
Browse files

minor changes

parent 8fe2729e
......@@ -140,9 +140,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
self._run_and_report_benchmark()
def benchmark_8_gpu_amp(self):
"""Test Keras model with eager, dist_strat and 8 GPUs with automatic mixed
precision.
"""
"""Test Keras model with eager, 8 GPUs with automatic mixed precision."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.data_dir
......@@ -152,7 +150,6 @@ class Resnet50CtlAccuracy(CtlBenchmark):
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.use_tf_function = True
# Add some thread tunings to improve performance.
FLAGS.datasets_num_private_threads = 14
self._run_and_report_benchmark()
......@@ -234,7 +231,6 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.batch_size = 128
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.use_tf_function = True
self._run_and_report_benchmark()
def benchmark_xla_1_gpu_amp(self):
......@@ -247,7 +243,6 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.batch_size = 128
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.use_tf_function = True
FLAGS.enable_xla = True
self._run_and_report_benchmark()
......@@ -283,7 +278,6 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.batch_size = 128 * 8 # 8 GPUs
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.use_tf_function = True
self._run_and_report_benchmark()
def benchmark_xla_8_gpu_amp(self):
......@@ -296,7 +290,6 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.batch_size = 128 * 8 # 8 GPUs
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.use_tf_function = True
FLAGS.enable_xla = True
self._run_and_report_benchmark()
......
......@@ -174,10 +174,10 @@ def run(flags_obj):
if flags_obj.fp16_implementation == "graph_rewrite":
if not flags_obj.use_tf_function:
raise ValueError("--fp16_implementation=graph_rewrite requires "
"use_tf_function to be true")
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer)
"--use_tf_function to be true")
loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer, loss_scale)
training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
'training_accuracy', dtype=tf.float32)
......@@ -213,16 +213,14 @@ def run(flags_obj):
loss += (tf.reduce_sum(model.losses) / num_replicas)
# Scale the loss
if flags_obj.fp16_implementation == "graph_rewrite":
loss = loss * tf.cast(loss_scale, loss.dtype)
if flags_obj.dtype == "fp16":
loss = optimizer.get_scaled_loss(loss)
grads = tape.gradient(loss, trainable_variables)
# Unscale the grads
if flags_obj.fp16_implementation == "graph_rewrite":
loss_scale_reciprocal = 1. / loss_scale
grads = [g * tf.cast(loss_scale_reciprocal, g.dtype) if g is not None
else None for g in grads]
if flags_obj.dtype == "fp16":
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(zip(grads, trainable_variables))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment