Unverified Commit 2ed43e66 authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Improve performance for Cifar ResNet benchmarks (#7178)

* Improve performance for Cifar ResNet benchmarks

* Revert batch size changes to benchmarks
parent 2fd76007
......@@ -115,7 +115,10 @@ def run(flags_obj):
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus)
num_gpus=flags_obj.num_gpus,
num_workers=distribution_utils.configure_cluster(),
all_reduce_alg=flags_obj.all_reduce_alg,
num_packs=flags_obj.num_packs)
strategy_scope = distribution_utils.get_strategy_scope(strategy)
......@@ -136,8 +139,12 @@ def run(flags_obj):
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
parse_record_fn=parse_record_keras,
datasets_num_private_threads=flags_obj.datasets_num_private_threads,
dtype=dtype)
eval_input_dataset = None
if not flags_obj.skip_eval:
eval_input_dataset = input_fn(
is_training=False,
data_dir=flags_obj.data_dir,
......@@ -151,8 +158,9 @@ def run(flags_obj):
model.compile(loss='categorical_crossentropy',
optimizer=optimizer,
run_eagerly=flags_obj.run_eagerly,
metrics=['categorical_accuracy'])
metrics=(['categorical_accuracy']
if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly)
callbacks = keras_common.get_callbacks(
learning_rate_schedule, cifar_main.NUM_IMAGES['train'])
......
......@@ -27,6 +27,7 @@ import functools
import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras import regularizers
BATCH_NORM_DECAY = 0.997
......@@ -56,44 +57,34 @@ def identity_building_block(input_tensor,
Output tensor for the block.
"""
filters1, filters2 = filters
if tf.keras.backend.image_data_format() == 'channels_last':
if backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = tf.keras.layers.Conv2D(filters1, kernel_size,
padding='same',
x = layers.Conv2D(filters1, kernel_size,
padding='same', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
name=bn_name_base + '2a',
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
x = tf.keras.layers.Activation('relu')(x)
x = layers.BatchNormalization(
axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(x, training=training)
x = layers.Activation('relu')(x)
x = tf.keras.layers.Conv2D(filters2, kernel_size,
padding='same',
x = layers.Conv2D(filters2, kernel_size,
padding='same', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
name=bn_name_base + '2b',
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
x = layers.BatchNormalization(
axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(x, training=training)
x = tf.keras.layers.add([x, input_tensor])
x = tf.keras.layers.Activation('relu')(x)
x = layers.add([x, input_tensor])
x = layers.Activation('relu')(x)
return x
......@@ -132,48 +123,34 @@ def conv_building_block(input_tensor,
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = tf.keras.layers.Conv2D(filters1, kernel_size, strides=strides,
padding='same',
x = layers.Conv2D(filters1, kernel_size, strides=strides,
padding='same', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
name=bn_name_base + '2a',
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
x = tf.keras.layers.Activation('relu')(x)
x = layers.BatchNormalization(
axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(x, training=training)
x = layers.Activation('relu')(x)
x = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same',
x = layers.Conv2D(filters2, kernel_size, padding='same', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
name=bn_name_base + '2b',
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
x = layers.BatchNormalization(
axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(x, training=training)
shortcut = tf.keras.layers.Conv2D(filters2, (1, 1), strides=strides,
shortcut = layers.Conv2D(filters2, (1, 1), strides=strides, use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '1')(input_tensor)
shortcut = tf.keras.layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '1',
momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON)(
shortcut, training=training)
shortcut = layers.BatchNormalization(
axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '1')(shortcut, training=training)
x = tf.keras.layers.add([x, shortcut])
x = tf.keras.layers.Activation('relu')(x)
x = layers.add([x, shortcut])
x = layers.Activation('relu')(x)
return x
......@@ -210,6 +187,7 @@ def resnet_block(input_tensor,
block='block_%d' % (i + 1), training=training)
return x
def resnet(num_blocks, classes=10, training=None):
"""Instantiates the ResNet architecture.
......@@ -239,21 +217,18 @@ def resnet(num_blocks, classes=10, training=None):
x = img_input
bn_axis = 3
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
x = tf.keras.layers.Conv2D(16, (3, 3),
x = layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
x = layers.Conv2D(16, (3, 3),
strides=(1, 1),
padding='valid',
padding='valid', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='conv1')(x)
x = tf.keras.layers.BatchNormalization(axis=bn_axis, name='bn_conv1',
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
x = tf.keras.layers.Activation('relu')(x)
epsilon=BATCH_NORM_EPSILON,
name='bn_conv1',)(x, training=training)
x = layers.Activation('relu')(x)
x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[16, 16],
stage=2, conv_strides=(1, 1), training=training)
......@@ -264,13 +239,11 @@ def resnet(num_blocks, classes=10, training=None):
x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[64, 64],
stage=4, conv_strides=(2, 2), training=training)
x = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')(x)
x = tf.keras.layers.Dense(classes, activation='softmax',
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
x = layers.Dense(classes, activation='softmax',
# kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='fc10')(x)
inputs = img_input
......
Markdown is supported
Attach a file by dragging &amp; dropping or selecting one.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.