Unverified Commit 2ed43e66 authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Improve performance for Cifar ResNet benchmarks (#7178)

* Improve performance for Cifar ResNet benchmarks

* Revert batch size changes to benchmarks
parent 2fd76007
...@@ -115,7 +115,10 @@ def run(flags_obj): ...@@ -115,7 +115,10 @@ def run(flags_obj):
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy, distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus) num_gpus=flags_obj.num_gpus,
num_workers=distribution_utils.configure_cluster(),
all_reduce_alg=flags_obj.all_reduce_alg,
num_packs=flags_obj.num_packs)
strategy_scope = distribution_utils.get_strategy_scope(strategy) strategy_scope = distribution_utils.get_strategy_scope(strategy)
...@@ -136,8 +139,12 @@ def run(flags_obj): ...@@ -136,8 +139,12 @@ def run(flags_obj):
data_dir=flags_obj.data_dir, data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size, batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs, num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras) parse_record_fn=parse_record_keras,
datasets_num_private_threads=flags_obj.datasets_num_private_threads,
dtype=dtype)
eval_input_dataset = None
if not flags_obj.skip_eval:
eval_input_dataset = input_fn( eval_input_dataset = input_fn(
is_training=False, is_training=False,
data_dir=flags_obj.data_dir, data_dir=flags_obj.data_dir,
...@@ -151,8 +158,9 @@ def run(flags_obj): ...@@ -151,8 +158,9 @@ def run(flags_obj):
model.compile(loss='categorical_crossentropy', model.compile(loss='categorical_crossentropy',
optimizer=optimizer, optimizer=optimizer,
run_eagerly=flags_obj.run_eagerly, metrics=(['categorical_accuracy']
metrics=['categorical_accuracy']) if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly)
callbacks = keras_common.get_callbacks( callbacks = keras_common.get_callbacks(
learning_rate_schedule, cifar_main.NUM_IMAGES['train']) learning_rate_schedule, cifar_main.NUM_IMAGES['train'])
......
...@@ -27,6 +27,7 @@ import functools ...@@ -27,6 +27,7 @@ import functools
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import layers from tensorflow.python.keras import layers
from tensorflow.python.keras import regularizers
BATCH_NORM_DECAY = 0.997 BATCH_NORM_DECAY = 0.997
...@@ -56,44 +57,34 @@ def identity_building_block(input_tensor, ...@@ -56,44 +57,34 @@ def identity_building_block(input_tensor,
Output tensor for the block. Output tensor for the block.
""" """
filters1, filters2 = filters filters1, filters2 = filters
if tf.keras.backend.image_data_format() == 'channels_last': if backend.image_data_format() == 'channels_last':
bn_axis = 3 bn_axis = 3
else: else:
bn_axis = 1 bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch' conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch' bn_name_base = 'bn' + str(stage) + block + '_branch'
x = tf.keras.layers.Conv2D(filters1, kernel_size, x = layers.Conv2D(filters1, kernel_size,
padding='same', padding='same', use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer= kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor) name=conv_name_base + '2a')(input_tensor)
x = tf.keras.layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(
name=bn_name_base + '2a', axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
momentum=BATCH_NORM_DECAY, name=bn_name_base + '2a')(x, training=training)
epsilon=BATCH_NORM_EPSILON)( x = layers.Activation('relu')(x)
x, training=training)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Conv2D(filters2, kernel_size, x = layers.Conv2D(filters2, kernel_size,
padding='same', padding='same', use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer= kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x) name=conv_name_base + '2b')(x)
x = tf.keras.layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(
name=bn_name_base + '2b', axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
momentum=BATCH_NORM_DECAY, name=bn_name_base + '2b')(x, training=training)
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
x = tf.keras.layers.add([x, input_tensor]) x = layers.add([x, input_tensor])
x = tf.keras.layers.Activation('relu')(x) x = layers.Activation('relu')(x)
return x return x
...@@ -132,48 +123,34 @@ def conv_building_block(input_tensor, ...@@ -132,48 +123,34 @@ def conv_building_block(input_tensor,
conv_name_base = 'res' + str(stage) + block + '_branch' conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch' bn_name_base = 'bn' + str(stage) + block + '_branch'
x = tf.keras.layers.Conv2D(filters1, kernel_size, strides=strides, x = layers.Conv2D(filters1, kernel_size, strides=strides,
padding='same', padding='same', use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer= kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor) name=conv_name_base + '2a')(input_tensor)
x = tf.keras.layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(
name=bn_name_base + '2a', axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
momentum=BATCH_NORM_DECAY, name=bn_name_base + '2a')(x, training=training)
epsilon=BATCH_NORM_EPSILON)( x = layers.Activation('relu')(x)
x, training=training)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same', x = layers.Conv2D(filters2, kernel_size, padding='same', use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer= kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x) name=conv_name_base + '2b')(x)
x = tf.keras.layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(
name=bn_name_base + '2b', axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
momentum=BATCH_NORM_DECAY, name=bn_name_base + '2b')(x, training=training)
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
shortcut = tf.keras.layers.Conv2D(filters2, (1, 1), strides=strides, shortcut = layers.Conv2D(filters2, (1, 1), strides=strides, use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer= kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '1')(input_tensor) name=conv_name_base + '1')(input_tensor)
shortcut = tf.keras.layers.BatchNormalization( shortcut = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '1', axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON)( name=bn_name_base + '1')(shortcut, training=training)
shortcut, training=training)
x = tf.keras.layers.add([x, shortcut]) x = layers.add([x, shortcut])
x = tf.keras.layers.Activation('relu')(x) x = layers.Activation('relu')(x)
return x return x
...@@ -210,6 +187,7 @@ def resnet_block(input_tensor, ...@@ -210,6 +187,7 @@ def resnet_block(input_tensor,
block='block_%d' % (i + 1), training=training) block='block_%d' % (i + 1), training=training)
return x return x
def resnet(num_blocks, classes=10, training=None): def resnet(num_blocks, classes=10, training=None):
"""Instantiates the ResNet architecture. """Instantiates the ResNet architecture.
...@@ -239,21 +217,18 @@ def resnet(num_blocks, classes=10, training=None): ...@@ -239,21 +217,18 @@ def resnet(num_blocks, classes=10, training=None):
x = img_input x = img_input
bn_axis = 3 bn_axis = 3
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x) x = layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
x = tf.keras.layers.Conv2D(16, (3, 3), x = layers.Conv2D(16, (3, 3),
strides=(1, 1), strides=(1, 1),
padding='valid', padding='valid', use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer= kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name='conv1')(x) name='conv1')(x)
x = tf.keras.layers.BatchNormalization(axis=bn_axis, name='bn_conv1', x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)( epsilon=BATCH_NORM_EPSILON,
x, training=training) name='bn_conv1',)(x, training=training)
x = tf.keras.layers.Activation('relu')(x) x = layers.Activation('relu')(x)
x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[16, 16], x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[16, 16],
stage=2, conv_strides=(1, 1), training=training) stage=2, conv_strides=(1, 1), training=training)
...@@ -264,13 +239,11 @@ def resnet(num_blocks, classes=10, training=None): ...@@ -264,13 +239,11 @@ def resnet(num_blocks, classes=10, training=None):
x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[64, 64], x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[64, 64],
stage=4, conv_strides=(2, 2), training=training) stage=4, conv_strides=(2, 2), training=training)
x = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')(x) rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
x = tf.keras.layers.Dense(classes, activation='softmax', x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
kernel_initializer='he_normal', x = layers.Dense(classes, activation='softmax',
kernel_regularizer= # kernel_initializer='he_normal',
tf.keras.regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name='fc10')(x) name='fc10')(x)
inputs = img_input inputs = img_input
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment