Commit e4a046e7 authored by Reed's avatar Reed Committed by Toby Boyd
Browse files

Mixed precision support (#6309)

* Mixed precision support

* Add TODOs
parent 8367cf6d
...@@ -102,9 +102,9 @@ def run(flags_obj): ...@@ -102,9 +102,9 @@ def run(flags_obj):
# TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready. # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.
dtype = flags_core.get_tf_dtype(flags_obj) dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16': if dtype == 'float16':
raise ValueError('dtype fp16 is not supported in Keras. Use the default ' policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
'value(fp32).') tf.keras.mixed_precision.experimental.set_policy(policy)
data_format = flags_obj.data_format data_format = flags_obj.data_format
if data_format is None: if data_format is None:
...@@ -120,7 +120,7 @@ def run(flags_obj): ...@@ -120,7 +120,7 @@ def run(flags_obj):
width=imagenet_main.DEFAULT_IMAGE_SIZE, width=imagenet_main.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_main.NUM_CHANNELS, num_channels=imagenet_main.NUM_CHANNELS,
num_classes=imagenet_main.NUM_CLASSES, num_classes=imagenet_main.NUM_CLASSES,
dtype=flags_core.get_tf_dtype(flags_obj)) dtype=dtype)
else: else:
distribution_utils.undo_set_up_synthetic_data() distribution_utils.undo_set_up_synthetic_data()
input_fn = imagenet_main.input_fn input_fn = imagenet_main.input_fn
...@@ -131,14 +131,16 @@ def run(flags_obj): ...@@ -131,14 +131,16 @@ def run(flags_obj):
batch_size=flags_obj.batch_size, batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs, num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras, parse_record_fn=parse_record_keras,
datasets_num_private_threads=flags_obj.datasets_num_private_threads) datasets_num_private_threads=flags_obj.datasets_num_private_threads,
dtype=dtype)
eval_input_dataset = input_fn( eval_input_dataset = input_fn(
is_training=False, is_training=False,
data_dir=flags_obj.data_dir, data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size, batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs, num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras) parse_record_fn=parse_record_keras,
dtype=dtype)
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy, distribution_strategy=flags_obj.distribution_strategy,
...@@ -148,7 +150,13 @@ def run(flags_obj): ...@@ -148,7 +150,13 @@ def run(flags_obj):
with strategy_scope: with strategy_scope:
optimizer = keras_common.get_optimizer() optimizer = keras_common.get_optimizer()
model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES) if dtype == 'float16':
# TODO(reedwm): Remove manually wrapping optimizer once mixed precision
# can be enabled with a single line of code.
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))
model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
dtype=dtype)
model.compile(loss='sparse_categorical_crossentropy', model.compile(loss='sparse_categorical_crossentropy',
optimizer=optimizer, optimizer=optimizer,
......
...@@ -174,7 +174,7 @@ def conv_block(input_tensor, ...@@ -174,7 +174,7 @@ def conv_block(input_tensor,
return x return x
def resnet50(num_classes): def resnet50(num_classes, dtype='float32'):
# TODO(tfboyd): add training argument, just like resnet56. # TODO(tfboyd): add training argument, just like resnet56.
"""Instantiates the ResNet50 architecture. """Instantiates the ResNet50 architecture.
...@@ -185,7 +185,7 @@ def resnet50(num_classes): ...@@ -185,7 +185,7 @@ def resnet50(num_classes):
A Keras model instance. A Keras model instance.
""" """
input_shape = (224, 224, 3) input_shape = (224, 224, 3)
img_input = layers.Input(shape=input_shape) img_input = layers.Input(shape=input_shape, dtype=dtype)
if backend.image_data_format() == 'channels_first': if backend.image_data_format() == 'channels_first':
x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)), x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
...@@ -232,10 +232,14 @@ def resnet50(num_classes): ...@@ -232,10 +232,14 @@ def resnet50(num_classes):
x = layers.GlobalAveragePooling2D(name='avg_pool')(x) x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
x = layers.Dense( x = layers.Dense(
num_classes, activation='softmax', num_classes,
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY), bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='fc1000')(x) name='fc1000')(x)
# TODO(reedwm): Remove manual casts once mixed precision can be enabled with a
# single line of code.
x = backend.cast(x, 'float32')
x = layers.Activation('softmax')(x)
# Create model. # Create model.
return models.Model(img_input, x, name='resnet50') return models.Model(img_input, x, name='resnet50')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment