"...git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "3f97baaefbc77088f50e13593bd49d385a981798"
Unverified commit 1cdc35c8, authored by Haoyu Zhang, committed by GitHub
Browse files

Handle data format and graph compile properly for better GPU performance (#6013)

* Handle data format in Keras ResNet model properly for better performance on GPU; Compile only the training graph when skip_eval flag is True

* Added data format fix to Keras Cifar model; Removed unnecessary import

* Add a comment to the skip_eval flag per Priya's request
parent 56cbd1f2
......@@ -108,6 +108,8 @@ def run(flags_obj):
per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
tf.keras.backend.set_image_data_format(flags_obj.data_format)
if flags_obj.use_synthetic_data:
input_fn = keras_common.get_synth_input_fn(
height=cifar_main.HEIGHT,
......@@ -160,6 +162,7 @@ def run(flags_obj):
validation_data = eval_input_dataset
if flags_obj.skip_eval:
tf.keras.backend.set_learning_phase(1)
num_eval_steps = None
validation_data = None
......
......@@ -95,6 +95,8 @@ def run(flags_obj):
raise ValueError('dtype fp16 is not supported in Keras. Use the default '
'value(fp32).')
tf.keras.backend.set_image_data_format(flags_obj.data_format)
per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
......@@ -149,6 +151,10 @@ def run(flags_obj):
validation_data = eval_input_dataset
if flags_obj.skip_eval:
# Only build the training graph. This reduces memory usage introduced by
# control flow ops in layers that have different implementations for
# training and inference (e.g., batch norm).
tf.keras.backend.set_learning_phase(1)
num_eval_steps = None
validation_data = None
......
......@@ -187,16 +187,18 @@ def resnet56(classes=100, training=None):
Returns:
A Keras model instance.
"""
# Determine proper input shape
input_shape = (32, 32, 3)
img_input = layers.Input(shape=input_shape)
if backend.image_data_format() == 'channels_first':
input_shape = (3, 32, 32)
x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(img_input)
bn_axis = 1
else: # channel_last
input_shape = (32, 32, 3)
x = img_input
bn_axis = 3
img_input = layers.Input(shape=input_shape)
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(img_input)
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
x = tf.keras.layers.Conv2D(16, (3, 3),
strides=(1, 1),
padding='valid',
......
......@@ -190,16 +190,18 @@ def resnet50(num_classes):
Returns:
A Keras model instance.
"""
# Determine proper input shape
input_shape = (224, 224, 3)
img_input = layers.Input(shape=input_shape)
if backend.image_data_format() == 'channels_first':
input_shape = (3, 224, 224)
x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(img_input)
bn_axis = 1
else:
input_shape = (224, 224, 3)
else: # channels_last
x = img_input
bn_axis = 3
img_input = layers.Input(shape=input_shape)
x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
x = layers.Conv2D(64, (7, 7),
strides=(2, 2),
padding='valid',
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment