"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "b98409cb5fc09d51e8aaee7d294efd23bdf47b58"
Commit 8f526987 authored by Vinh Nguyen


Use the existing flag --fp16_implementation in official/utils/flags/_performance.py to enable automatic mixed precision.
parent 6c965160
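
With this change, graph-rewrite automatic mixed precision is requested through the existing --fp16_implementation flag rather than a dedicated boolean. Based on the run() logic changed below, a hypothetical invocation (the script name and the extra flags are illustrative, not taken from this commit) would look roughly like:

python keras_imagenet_main.py --dtype=fp16 --fp16_implementation=graph_rewrite --num_gpus=1 --batch_size=256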
@@ -90,7 +90,7 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
     FLAGS.dtype = 'fp32'
     FLAGS.enable_eager = True
-    FLAGS.automatic_mixed_precision = True
+    FLAGS.fp16_implementation = 'graph_rewrite'
     # Add some thread tunings to improve performance.
     FLAGS.datasets_num_private_threads = 14
     FLAGS.use_tensor_lr = True
@@ -326,7 +326,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.num_gpus = 1
     FLAGS.enable_eager = True
-    FLAGS.automatic_mixed_precision = True
+    FLAGS.fp16_implementation = 'graph_rewrite'
     FLAGS.distribution_strategy = 'default'
     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
     FLAGS.batch_size = 256
@@ -350,7 +350,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.num_gpus = 1
     FLAGS.enable_eager = True
-    FLAGS.automatic_mixed_precision = True
+    FLAGS.fp16_implementation = 'graph_rewrite'
     FLAGS.enable_xla = True
     FLAGS.distribution_strategy = 'default'
     FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
@@ -505,7 +505,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.num_gpus = 8
     FLAGS.enable_eager = True
-    FLAGS.automatic_mixed_precision = True
+    FLAGS.fp16_implementation = 'graph_rewrite'
     FLAGS.distribution_strategy = 'default'
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
     FLAGS.batch_size = 256 * 8  # 8 GPUs
@@ -542,7 +542,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.num_gpus = 8
     FLAGS.enable_eager = True
-    FLAGS.automatic_mixed_precision = True
+    FLAGS.fp16_implementation = 'graph_rewrite'
     FLAGS.enable_xla = True
     FLAGS.distribution_strategy = 'default'
     FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp')
...
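Any further AMP benchmark variants would follow the same pattern as the hunks above. A rough sketch of such a method (the method name is hypothetical, and the _setup and _run_and_report_benchmark helpers are assumed to exist in the surrounding benchmark class, as they are not shown in this diff):

  def benchmark_1_gpu_fp16_graph_rewrite(self):
    """Hypothetical 1-GPU benchmark using graph-rewrite mixed precision."""
    self._setup()  # assumed helper: resets FLAGS between benchmark runs
    FLAGS.num_gpus = 1
    FLAGS.enable_eager = True
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16_graph_rewrite')
    FLAGS.batch_size = 256
    self._run_and_report_benchmark()  # assumed helper: trains and reports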
@@ -253,6 +253,7 @@ def define_keras_flags(dynamic_loss_scale=True):
      datasets_num_private_threads=True,
      dynamic_loss_scale=dynamic_loss_scale,
      loss_scale=True,
+      fp16_implementation=True,
      tf_data_experimental_slack=True,
      enable_xla=True,
      force_v2_in_keras_compile=True)
@@ -307,9 +308,6 @@ def define_keras_flags(dynamic_loss_scale=True):
   flags.DEFINE_boolean(
       name='enable_get_next_as_optional', default=False,
       help='Enable get_next_as_optional behavior in DistributedIterator.')
-  flags.DEFINE_boolean(
-      name='automatic_mixed_precision', default=False,
-      help='Enable automatic mixed precision training via a graph rewrite.')
 
 def get_synth_input_fn(height, width, num_channels, num_classes,
                        dtype=tf.float32, drop_remainder=True):
...
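The commit reuses the flag already defined in official/utils/flags/_performance.py rather than adding a new one. A minimal sketch of how such an enum flag is typically declared with absl (the default value and help text here are assumptions, not copied from that file):

from absl import flags

flags.DEFINE_enum(
    name='fp16_implementation', default='casting',
    enum_values=('casting', 'graph_rewrite'),
    help='How fp16 arithmetic is implemented: cast-based Keras mixed '
         'precision or the automatic mixed precision graph rewrite.')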
@@ -96,8 +96,11 @@ def run(flags_obj):
   dtype = flags_core.get_tf_dtype(flags_obj)
 
   if dtype == 'float16':
-    policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
-    tf.keras.mixed_precision.experimental.set_policy(policy)
+    # Mixed precision training via graph rewrite should not be used in
+    # conjunction with tf.keras.mixed_precision.
+    if flags_obj.fp16_implementation != 'graph_rewrite':
+      policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
+      tf.keras.mixed_precision.experimental.set_policy(policy)
 
   data_format = flags_obj.data_format
   if data_format is None:
@@ -182,15 +185,13 @@ def run(flags_obj):
   if dtype == 'float16':
     # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
     # can be enabled with a single line of code.
-    optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-        optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
-                                                        default_for_fp16=128))
-  if flags_obj.automatic_mixed_precision:
-    if dtype == 'float16':
-      raise RuntimeError("Automatic mixed precision should not be called in conjunction with "
-                         "other types of mixed precision training. Set --dtype=fp32 instead.")
-    optimizer = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
+    if flags_obj.fp16_implementation == 'graph_rewrite':
+      optimizer = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
+    else:
+      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
+          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
+                                                          default_for_fp16=128))
 
   if flags_obj.use_trivial_model:
     model = trivial_model.trivial_model(
         imagenet_preprocessing.NUM_CLASSES, dtype)
...
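Putting the two fp16 code paths from the hunk above side by side, a self-contained sketch (the helper name is illustrative; intended to run against a TensorFlow release of this era where these experimental APIs still exist):

import tensorflow as tf

def wrap_optimizer_for_fp16(optimizer, fp16_implementation, loss_scale=128):
  # Illustrative helper mirroring the logic in run(): pick between the
  # graph-rewrite AMP path and the explicit Keras loss-scale optimizer.
  if fp16_implementation == 'graph_rewrite':
    # TensorFlow inserts fp16 casts and loss scaling during the graph rewrite.
    return tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer)
  # Cast-based path: keep fp32 variables and wrap with a loss-scale optimizer.
  return tf.keras.mixed_precision.experimental.LossScaleOptimizer(
      optimizer, loss_scale=loss_scale)

optimizer = wrap_optimizer_for_fp16(
    tf.keras.optimizers.SGD(learning_rate=0.1), 'graph_rewrite')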