Unverified Commit fe1a9089 authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Do not use bias in conv2d layers in Keras ResNet model (#6017)

* Set use_bias=False for conv layers in Keras ResNet model

* Removed bias regularizer from Conv2D layers (which should have no effect after bias is removed)

* Setting default data format based on available devices.
parent 1cdc35c8
......@@ -108,7 +108,11 @@ def run(flags_obj):
per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
tf.keras.backend.set_image_data_format(flags_obj.data_format)
data_format = flags_obj.data_format
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
if flags_obj.use_synthetic_data:
input_fn = keras_common.get_synth_input_fn(
......
......@@ -95,7 +95,11 @@ def run(flags_obj):
raise ValueError('dtype fp16 is not supported in Keras. Use the default '
'value(fp32).')
tf.keras.backend.set_image_data_format(flags_obj.data_format)
data_format = flags_obj.data_format
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
......
......@@ -64,10 +64,9 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(filters1, (1, 1),
x = layers.Conv2D(filters1, (1, 1), use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
......@@ -76,10 +75,9 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters2, kernel_size,
padding='same',
padding='same', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
......@@ -87,10 +85,9 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
name=bn_name_base + '2b')(x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters3, (1, 1),
x = layers.Conv2D(filters3, (1, 1), use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
......@@ -134,9 +131,9 @@ def conv_block(input_tensor,
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(filters1, (1, 1), kernel_initializer='he_normal',
x = layers.Conv2D(filters1, (1, 1), use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
......@@ -145,9 +142,8 @@ def conv_block(input_tensor,
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters2, kernel_size, strides=strides, padding='same',
kernel_initializer='he_normal',
use_bias=False, kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
......@@ -155,20 +151,18 @@ def conv_block(input_tensor,
name=bn_name_base + '2b')(x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters3, (1, 1),
x = layers.Conv2D(filters3, (1, 1), use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2c')(x)
shortcut = layers.Conv2D(filters3, (1, 1), strides=strides,
shortcut = layers.Conv2D(filters3, (1, 1), strides=strides, use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '1')(input_tensor)
shortcut = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
......@@ -204,10 +198,9 @@ def resnet50(num_classes):
x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
x = layers.Conv2D(64, (7, 7),
strides=(2, 2),
padding='valid',
padding='valid', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='conv1')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment