Unverified commit 1cdc35c8, authored by Haoyu Zhang, committed by GitHub
Browse files

Handle data format and graph compile properly for better GPU performance (#6013)

* Handle data format in Keras ResNet model properly for better performance on GPU; Compile only the training graph when skip_eval flag is True

* Added data format fix to Keras Cifar model; Removed unnecessary import

* Add a comment to the skip_eval flag per Priya's request
parent 56cbd1f2
......@@ -108,6 +108,8 @@ def run(flags_obj):
per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
tf.keras.backend.set_image_data_format(flags_obj.data_format)
if flags_obj.use_synthetic_data:
input_fn = keras_common.get_synth_input_fn(
height=cifar_main.HEIGHT,
......@@ -160,6 +162,7 @@ def run(flags_obj):
validation_data = eval_input_dataset
if flags_obj.skip_eval:
tf.keras.backend.set_learning_phase(1)
num_eval_steps = None
validation_data = None
......
......@@ -95,6 +95,8 @@ def run(flags_obj):
raise ValueError('dtype fp16 is not supported in Keras. Use the default '
'value(fp32).')
tf.keras.backend.set_image_data_format(flags_obj.data_format)
per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
......@@ -149,6 +151,10 @@ def run(flags_obj):
validation_data = eval_input_dataset
if flags_obj.skip_eval:
# Only build the training graph. This reduces memory usage introduced by
# control flow ops in layers that have different implementations for
# training and inference (e.g., batch norm).
tf.keras.backend.set_learning_phase(1)
num_eval_steps = None
validation_data = None
......
......@@ -187,16 +187,18 @@ def resnet56(classes=100, training=None):
Returns:
A Keras model instance.
"""
# Determine proper input shape
input_shape = (32, 32, 3)
img_input = layers.Input(shape=input_shape)
if backend.image_data_format() == 'channels_first':
input_shape = (3, 32, 32)
x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(img_input)
bn_axis = 1
else: # channels_last
input_shape = (32, 32, 3)
x = img_input
bn_axis = 3
img_input = layers.Input(shape=input_shape)
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(img_input)
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
x = tf.keras.layers.Conv2D(16, (3, 3),
strides=(1, 1),
padding='valid',
......
......@@ -190,16 +190,18 @@ def resnet50(num_classes):
Returns:
A Keras model instance.
"""
# Determine proper input shape
input_shape = (224, 224, 3)
img_input = layers.Input(shape=input_shape)
if backend.image_data_format() == 'channels_first':
input_shape = (3, 224, 224)
x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(img_input)
bn_axis = 1
else:
input_shape = (224, 224, 3)
else: # channels_last
x = img_input
bn_axis = 3
img_input = layers.Input(shape=input_shape)
x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
x = layers.Conv2D(64, (7, 7),
strides=(2, 2),
padding='valid',
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment