Commit e4a046e7 authored by Reed, committed by Toby Boyd

Mixed precision support (#6309)

* Mixed precision support

* Add TODOs
parent 8367cf6d
@@ -102,9 +102,9 @@ def run(flags_obj):
   # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.
   dtype = flags_core.get_tf_dtype(flags_obj)
-  if dtype == 'fp16':
-    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
-                     'value(fp32).')
+  if dtype == 'float16':
+    policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
+    tf.keras.mixed_precision.experimental.set_policy(policy)

   data_format = flags_obj.data_format
   if data_format is None:
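For reference, a minimal standalone sketch of the policy setup this hunk introduces, using the experimental mixed-precision API that the diff itself calls (the flags_core plumbing is repo-specific and omitted here):

import tensorflow as tf

# Assumption: the experimental mixed-precision API from the era of this commit
# is available. Under 'infer_float32_vars', layers infer their compute dtype
# from float16 inputs while keeping their variables in float32.
policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
tf.keras.mixed_precision.experimental.set_policy(policy)

# Any Keras layers built after this point pick up the policy.

In later TensorFlow releases the rough equivalent is the single call tf.keras.mixed_precision.set_global_policy('mixed_float16'), which is the kind of one-line enablement the TODOs below anticipate.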
@@ -120,7 +120,7 @@ def run(flags_obj):
         width=imagenet_main.DEFAULT_IMAGE_SIZE,
         num_channels=imagenet_main.NUM_CHANNELS,
         num_classes=imagenet_main.NUM_CLASSES,
-        dtype=flags_core.get_tf_dtype(flags_obj))
+        dtype=dtype)
   else:
     distribution_utils.undo_set_up_synthetic_data()
     input_fn = imagenet_main.input_fn
@@ -131,14 +131,16 @@ def run(flags_obj):
       batch_size=flags_obj.batch_size,
       num_epochs=flags_obj.train_epochs,
       parse_record_fn=parse_record_keras,
-      datasets_num_private_threads=flags_obj.datasets_num_private_threads)
+      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
+      dtype=dtype)

   eval_input_dataset = input_fn(
       is_training=False,
       data_dir=flags_obj.data_dir,
       batch_size=flags_obj.batch_size,
       num_epochs=flags_obj.train_epochs,
-      parse_record_fn=parse_record_keras)
+      parse_record_fn=parse_record_keras,
+      dtype=dtype)

   strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=flags_obj.distribution_strategy,
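The dtype handed to input_fn above controls the precision of the preprocessed images. The bodies of parse_record_keras and imagenet_main.input_fn are not part of this diff, so the sketch below only illustrates the assumed effect of that argument on a tf.data pipeline (function name, shapes, and values are hypothetical):

import tensorflow as tf

def cast_for_mixed_precision(images, labels, dtype=tf.float16):
  # Assumed behaviour of the dtype argument: cast the decoded images so the
  # first layer of the model already receives float16 tensors, while labels
  # stay integer for sparse_categorical_crossentropy.
  return tf.cast(images, dtype), labels

# Toy stand-in for the ImageNet input pipeline (shapes are illustrative).
images = tf.random.uniform([8, 224, 224, 3], dtype=tf.float32)
labels = tf.random.uniform([8], maxval=1000, dtype=tf.int32)
dataset = (tf.data.Dataset.from_tensor_slices((images, labels))
           .batch(4)
           .map(cast_for_mixed_precision))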
@@ -148,7 +150,13 @@ def run(flags_obj):
   with strategy_scope:
     optimizer = keras_common.get_optimizer()
-    model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)
+    if dtype == 'float16':
+      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
+      # can be enabled with a single line of code.
+      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
+          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))
+    model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
+                                  dtype=dtype)

     model.compile(loss='sparse_categorical_crossentropy',
                   optimizer=optimizer,
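Wrapping the optimizer in LossScaleOptimizer is what keeps small float16 gradients from underflowing: the loss is multiplied by the scale before backpropagation and the gradients are divided by the same factor before being applied. A minimal sketch of the same wrapping, using the experimental API called in this hunk and a fixed loss scale of 128 standing in for flags_core.get_loss_scale(flags_obj):

import tensorflow as tf

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)

# The wrapper scales the loss up and the gradients back down by the same
# factor; 128 is an illustrative fixed scale, not the repo's flag value.
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    optimizer, loss_scale=128)

Later TensorFlow releases expose a non-experimental tf.keras.mixed_precision.LossScaleOptimizer, which defaults to dynamic loss scaling.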
@@ -174,7 +174,7 @@ def conv_block(input_tensor,
   return x


-def resnet50(num_classes):
+def resnet50(num_classes, dtype='float32'):
   # TODO(tfboyd): add training argument, just like resnet56.
   """Instantiates the ResNet50 architecture.
@@ -185,7 +185,7 @@ def resnet50(num_classes):
     A Keras model instance.
   """
   input_shape = (224, 224, 3)
-  img_input = layers.Input(shape=input_shape)
+  img_input = layers.Input(shape=input_shape, dtype=dtype)

   if backend.image_data_format() == 'channels_first':
     x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
@@ -232,10 +232,14 @@ def resnet50(num_classes):
   x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
   x = layers.Dense(
-      num_classes, activation='softmax',
+      num_classes,
       kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
       bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
       name='fc1000')(x)
+  # TODO(reedwm): Remove manual casts once mixed precision can be enabled with a
+  # single line of code.
+  x = backend.cast(x, 'float32')
+  x = layers.Activation('softmax')(x)

   # Create model.
   return models.Model(img_input, x, name='resnet50')
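The final hunk splits the softmax off the Dense layer: under the float16 policy the fc1000 logits come out in half precision, so they are cast to float32 before the softmax to keep the exponentiation and normalization numerically stable and to give the loss float32 inputs. A minimal sketch of that pattern in isolation, with a hypothetical 2048-wide input standing in for ResNet50's pooled features:

import tensorflow as tf
from tensorflow.keras import backend, layers

# Hypothetical feature width standing in for ResNet50's pooled features.
features = layers.Input(shape=(2048,), dtype='float16')
logits = layers.Dense(1000, name='fc1000')(features)
# Cast before the softmax so it runs in float32 even when the layers above
# computed in float16 under a mixed-precision policy.
logits = backend.cast(logits, 'float32')
probs = layers.Activation('softmax')(logits)
model = tf.keras.Model(features, probs)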