"vscode:/vscode.git/clone" did not exist on "86cf154be928315cc33234b87bbeca6bbc006e68"
Unverified commit 1cdc35c8, authored by Haoyu Zhang and committed by GitHub
Browse files

Handle data format and graph compile properly for better GPU performance (#6013)

* Handle data format in Keras ResNet model properly for better performance on GPU; compile only the training graph when the skip_eval flag is True

* Added data format fix to Keras CIFAR model; removed unnecessary import

* Add a comment to the skip_eval flag per Priya's request
parent 56cbd1f2
...@@ -108,6 +108,8 @@ def run(flags_obj): ...@@ -108,6 +108,8 @@ def run(flags_obj):
per_device_batch_size = distribution_utils.per_device_batch_size( per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)) flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
tf.keras.backend.set_image_data_format(flags_obj.data_format)
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
input_fn = keras_common.get_synth_input_fn( input_fn = keras_common.get_synth_input_fn(
height=cifar_main.HEIGHT, height=cifar_main.HEIGHT,
...@@ -160,6 +162,7 @@ def run(flags_obj): ...@@ -160,6 +162,7 @@ def run(flags_obj):
validation_data = eval_input_dataset validation_data = eval_input_dataset
if flags_obj.skip_eval: if flags_obj.skip_eval:
tf.keras.backend.set_learning_phase(1)
num_eval_steps = None num_eval_steps = None
validation_data = None validation_data = None
......
...@@ -95,6 +95,8 @@ def run(flags_obj): ...@@ -95,6 +95,8 @@ def run(flags_obj):
raise ValueError('dtype fp16 is not supported in Keras. Use the default ' raise ValueError('dtype fp16 is not supported in Keras. Use the default '
'value(fp32).') 'value(fp32).')
tf.keras.backend.set_image_data_format(flags_obj.data_format)
per_device_batch_size = distribution_utils.per_device_batch_size( per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)) flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
...@@ -149,6 +151,10 @@ def run(flags_obj): ...@@ -149,6 +151,10 @@ def run(flags_obj):
validation_data = eval_input_dataset validation_data = eval_input_dataset
if flags_obj.skip_eval: if flags_obj.skip_eval:
# Only build the training graph. This reduces memory usage introduced by
# control flow ops in layers that have different implementations for
# training and inference (e.g., batch norm).
tf.keras.backend.set_learning_phase(1)
num_eval_steps = None num_eval_steps = None
validation_data = None validation_data = None
......
...@@ -187,16 +187,18 @@ def resnet56(classes=100, training=None): ...@@ -187,16 +187,18 @@ def resnet56(classes=100, training=None):
Returns: Returns:
A Keras model instance. A Keras model instance.
""" """
# Determine proper input shape input_shape = (32, 32, 3)
img_input = layers.Input(shape=input_shape)
if backend.image_data_format() == 'channels_first': if backend.image_data_format() == 'channels_first':
input_shape = (3, 32, 32) x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(img_input)
bn_axis = 1 bn_axis = 1
else: # channel_last else: # channel_last
input_shape = (32, 32, 3) x = img_input
bn_axis = 3 bn_axis = 3
img_input = layers.Input(shape=input_shape) x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(img_input)
x = tf.keras.layers.Conv2D(16, (3, 3), x = tf.keras.layers.Conv2D(16, (3, 3),
strides=(1, 1), strides=(1, 1),
padding='valid', padding='valid',
......
...@@ -190,16 +190,18 @@ def resnet50(num_classes): ...@@ -190,16 +190,18 @@ def resnet50(num_classes):
Returns: Returns:
A Keras model instance. A Keras model instance.
""" """
# Determine proper input shape input_shape = (224, 224, 3)
img_input = layers.Input(shape=input_shape)
if backend.image_data_format() == 'channels_first': if backend.image_data_format() == 'channels_first':
input_shape = (3, 224, 224) x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(img_input)
bn_axis = 1 bn_axis = 1
else: else: # channels_last
input_shape = (224, 224, 3) x = img_input
bn_axis = 3 bn_axis = 3
img_input = layers.Input(shape=input_shape) x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
x = layers.Conv2D(64, (7, 7), x = layers.Conv2D(64, (7, 7),
strides=(2, 2), strides=(2, 2),
padding='valid', padding='valid',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment