Unverified Commit 1cdc35c8 authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Handle data format and graph compile properly for better GPU performance (#6013)

* Handle data format in Keras ResNet model properly for better performance on GPU; Compile only the training graph when skip_eval flag is True

* Added data format fix to Keras Cifar model; Removed unnecessary import

* Add a comment to the skip_eval flag per Priya's request
parent 56cbd1f2
...@@ -108,6 +108,8 @@ def run(flags_obj): ...@@ -108,6 +108,8 @@ def run(flags_obj):
per_device_batch_size = distribution_utils.per_device_batch_size( per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)) flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
tf.keras.backend.set_image_data_format(flags_obj.data_format)
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
input_fn = keras_common.get_synth_input_fn( input_fn = keras_common.get_synth_input_fn(
height=cifar_main.HEIGHT, height=cifar_main.HEIGHT,
...@@ -160,6 +162,7 @@ def run(flags_obj): ...@@ -160,6 +162,7 @@ def run(flags_obj):
validation_data = eval_input_dataset validation_data = eval_input_dataset
if flags_obj.skip_eval: if flags_obj.skip_eval:
tf.keras.backend.set_learning_phase(1)
num_eval_steps = None num_eval_steps = None
validation_data = None validation_data = None
......
...@@ -95,6 +95,8 @@ def run(flags_obj): ...@@ -95,6 +95,8 @@ def run(flags_obj):
raise ValueError('dtype fp16 is not supported in Keras. Use the default ' raise ValueError('dtype fp16 is not supported in Keras. Use the default '
'value(fp32).') 'value(fp32).')
tf.keras.backend.set_image_data_format(flags_obj.data_format)
per_device_batch_size = distribution_utils.per_device_batch_size( per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)) flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
...@@ -149,6 +151,10 @@ def run(flags_obj): ...@@ -149,6 +151,10 @@ def run(flags_obj):
validation_data = eval_input_dataset validation_data = eval_input_dataset
if flags_obj.skip_eval: if flags_obj.skip_eval:
# Only build the training graph. This reduces memory usage introduced by
# control flow ops in layers that have different implementations for
# training and inference (e.g., batch norm).
tf.keras.backend.set_learning_phase(1)
num_eval_steps = None num_eval_steps = None
validation_data = None validation_data = None
......
...@@ -187,16 +187,18 @@ def resnet56(classes=100, training=None): ...@@ -187,16 +187,18 @@ def resnet56(classes=100, training=None):
Returns: Returns:
A Keras model instance. A Keras model instance.
""" """
# Determine proper input shape input_shape = (32, 32, 3)
img_input = layers.Input(shape=input_shape)
if backend.image_data_format() == 'channels_first': if backend.image_data_format() == 'channels_first':
input_shape = (3, 32, 32) x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(img_input)
bn_axis = 1 bn_axis = 1
else: # channel_last else: # channel_last
input_shape = (32, 32, 3) x = img_input
bn_axis = 3 bn_axis = 3
img_input = layers.Input(shape=input_shape) x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(img_input)
x = tf.keras.layers.Conv2D(16, (3, 3), x = tf.keras.layers.Conv2D(16, (3, 3),
strides=(1, 1), strides=(1, 1),
padding='valid', padding='valid',
......
...@@ -190,16 +190,18 @@ def resnet50(num_classes): ...@@ -190,16 +190,18 @@ def resnet50(num_classes):
Returns: Returns:
A Keras model instance. A Keras model instance.
""" """
# Determine proper input shape input_shape = (224, 224, 3)
img_input = layers.Input(shape=input_shape)
if backend.image_data_format() == 'channels_first': if backend.image_data_format() == 'channels_first':
input_shape = (3, 224, 224) x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(img_input)
bn_axis = 1 bn_axis = 1
else: else: # channels_last
input_shape = (224, 224, 3) x = img_input
bn_axis = 3 bn_axis = 3
img_input = layers.Input(shape=input_shape) x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
x = layers.Conv2D(64, (7, 7), x = layers.Conv2D(64, (7, 7),
strides=(2, 2), strides=(2, 2),
padding='valid', padding='valid',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment