Commit 63d84bff authored by Vinh Nguyen's avatar Vinh Nguyen
Browse files

adding automatic mixed precision training support to Resnet

parent 58340818
......@@ -310,6 +310,9 @@ def define_keras_flags(dynamic_loss_scale=True):
flags.DEFINE_boolean(
name='enable_get_next_as_optional', default=False,
help='Enable get_next_as_optional behavior in DistributedIterator.')
flags.DEFINE_boolean(
name='automatic_mixed_precision', default=False,
help='Enable automatic mixed precision training via a graph rewrite.')
def get_synth_input_fn(height, width, num_channels, num_classes,
dtype=tf.float32, drop_remainder=True):
......
......@@ -181,7 +181,12 @@ def run(flags_obj):
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
default_for_fp16=128))
if flags_obj.automatic_mixed_precision:
if dtype == 'float16':
raise RuntimeError("Automatic mixed precision should not be called in conjunction with "
"other types of mixed precision training. Set --dtype=fp32 instead.")
optimizer = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
if flags_obj.use_trivial_model:
model = trivial_model.trivial_model(
imagenet_preprocessing.NUM_CLASSES, dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment