"git@developer.sourcefind.cn:OpenDAS/d2go.git" did not exist on "5bf4cc7d39f242b0c002c61431b2ba68552bce68"
Commit ffa522ea authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Support fp16 using tf.keras.mixed_precision in CTL resnet.

To test, I ran the following command:

python resnet_ctl_imagenet_main.py --batch_size=2048 --data_dir ~/imagenet --datasets_num_private_threads=14 --epochs_between_evals=10 --model_dir ~/tmp_model_dir --clean --num_gpus=8 --train_epochs=90 --dtype=fp16

I got 76.15% final evaluation accuracy.

PiperOrigin-RevId: 278010061
parent 9df6a3d6
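
The change enables fp16 in two steps: it sets a global 'mixed_float16' Keras dtype policy so layers compute in float16 while keeping float32 variables, and it wraps the optimizer in a loss-scale optimizer so small float16 gradients do not underflow. A minimal sketch of that pattern is below, assuming a TensorFlow release contemporaneous with this commit where the experimental mixed precision API is available; the SGD optimizer here is illustrative, not the one built in resnet_ctl_imagenet_main.py.

import tensorflow as tf

# Set a global dtype policy: layers compute in float16, variables stay float32.
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy('mixed_float16')
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

# Wrap the optimizer so the loss is scaled before backprop. The commit uses a
# fixed loss scale that defaults to 128 for fp16.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)  # illustrative
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    optimizer, loss_scale=128)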
@@ -140,8 +140,22 @@ class Resnet50CtlAccuracy(CtlBenchmark):
     FLAGS.datasets_num_private_threads = 14
     self._run_and_report_benchmark()
 
+  def benchmark_8_gpu_fp16(self):
+    """Test Keras model with eager, 8 GPUs with tf.keras mixed precision."""
+    self._setup()
+    FLAGS.num_gpus = 8
+    FLAGS.data_dir = self.data_dir
+    FLAGS.batch_size = 256 * 8
+    FLAGS.train_epochs = 90
+    FLAGS.epochs_between_evals = 10
+    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
+    FLAGS.dtype = 'fp16'
+    # Add some thread tunings to improve performance.
+    FLAGS.datasets_num_private_threads = 14
+    self._run_and_report_benchmark()
+
   def benchmark_8_gpu_amp(self):
-    """Test Keras model with eager, 8 GPUs with automatic mixed precision."""
+    """Test Keras model with 8 GPUs and mixed precision via graph rewrite."""
     self._setup()
     FLAGS.num_gpus = 8
     FLAGS.data_dir = self.data_dir
...
@@ -188,7 +188,11 @@ def run(flags_obj):
       enable_xla=flags_obj.enable_xla)
 
   dtype = flags_core.get_tf_dtype(flags_obj)
-  if dtype == tf.bfloat16:
+  if dtype == tf.float16:
+    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
+        'mixed_float16')
+    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
+  elif dtype == tf.bfloat16:
     policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
         'mixed_bfloat16')
     tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
@@ -235,7 +239,13 @@ def run(flags_obj):
       compute_lr_on_cpu=True)
   optimizer = common.get_optimizer(lr_schedule)
 
-  if flags_obj.fp16_implementation == 'graph_rewrite':
+  if dtype == tf.float16:
+    loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
+    optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
+        optimizer, loss_scale)
+  elif flags_obj.fp16_implementation == 'graph_rewrite':
+    # `dtype` is still float32 in this case. We built the graph in float32 and
+    # let the graph rewrite change parts of it to float16.
     if not flags_obj.use_tf_function:
       raise ValueError('--fp16_implementation=graph_rewrite requires '
                        '--use_tf_function to be true')
...
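
The training-step changes that consume the loss scale are elided from the diff above. In a custom training loop, a LossScaleOptimizer is typically used as in the following sketch; this is an assumed pattern for illustration, not necessarily the exact code in this commit.

import tensorflow as tf

@tf.function
def train_step(model, optimizer, loss_fn, images, labels):
  """One training step with explicit loss scaling for fp16."""
  with tf.GradientTape() as tape:
    logits = model(images, training=True)
    loss = loss_fn(labels, logits)
    # Scale the loss so float16 gradients do not underflow.
    scaled_loss = optimizer.get_scaled_loss(loss)
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Unscale before applying so the update uses true gradient magnitudes.
  grads = optimizer.get_unscaled_gradients(scaled_grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss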