Unverified commit fe1a9089 authored by Haoyu Zhang, committed by GitHub

Do not use bias in conv2d layers in Keras ResNet model (#6017)

* Set use_bias=False for conv layers in Keras ResNet model

* Removed bias regularizer from Conv2D layers (which should have no effect after bias is removed)

* Set the default data format based on available devices.
parent 1cdc35c8
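
Why dropping the conv bias is safe here: every Conv2D in this model is immediately followed by BatchNormalization, which subtracts the per-channel mean and then adds its own learned offset (beta). Any constant bias the convolution adds is cancelled by the centering step, and beta takes over its role. A minimal sketch of the pattern (illustrative only, not code from this commit):

import tensorflow as tf

inputs = tf.keras.Input(shape=(224, 224, 3))
# The conv bias would be cancelled by BN's mean subtraction, so omit it.
x = tf.keras.layers.Conv2D(64, (7, 7), strides=(2, 2), use_bias=False)(inputs)
x = tf.keras.layers.BatchNormalization()(x)  # beta plays the role of the bias
x = tf.keras.layers.Activation('relu')(x)
model = tf.keras.Model(inputs, x)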
@@ -108,7 +108,11 @@ def run(flags_obj):
   per_device_batch_size = distribution_utils.per_device_batch_size(
       flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))

-  tf.keras.backend.set_image_data_format(flags_obj.data_format)
+  data_format = flags_obj.data_format
+  if data_format is None:
+    data_format = ('channels_first'
+                   if tf.test.is_built_with_cuda() else 'channels_last')
+  tf.keras.backend.set_image_data_format(data_format)

   if flags_obj.use_synthetic_data:
     input_fn = keras_common.get_synth_input_fn(
...
@@ -95,7 +95,11 @@ def run(flags_obj):
     raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                      'value(fp32).')

-  tf.keras.backend.set_image_data_format(flags_obj.data_format)
+  data_format = flags_obj.data_format
+  if data_format is None:
+    data_format = ('channels_first'
+                   if tf.test.is_built_with_cuda() else 'channels_last')
+  tf.keras.backend.set_image_data_format(data_format)

   per_device_batch_size = distribution_utils.per_device_batch_size(
       flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
...
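
Both training scripts above now fall back to a device-appropriate layout when --data_format is not set: channels_first (NCHW), which is generally faster for cuDNN convolutions, when TensorFlow was built with CUDA, and channels_last (NHWC) otherwise. A self-contained sketch of the same selection logic (the resolve_data_format helper is hypothetical, not part of the patch):

import tensorflow as tf

def resolve_data_format(requested=None):
  # Honor an explicit request; otherwise prefer NCHW on CUDA builds,
  # since cuDNN convolution kernels are fastest in that layout.
  if requested is not None:
    return requested
  return 'channels_first' if tf.test.is_built_with_cuda() else 'channels_last'

tf.keras.backend.set_image_data_format(resolve_data_format())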
@@ -64,10 +64,9 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
   conv_name_base = 'res' + str(stage) + block + '_branch'
   bn_name_base = 'bn' + str(stage) + block + '_branch'

-  x = layers.Conv2D(filters1, (1, 1),
+  x = layers.Conv2D(filters1, (1, 1), use_bias=False,
                     kernel_initializer='he_normal',
                     kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                     name=conv_name_base + '2a')(input_tensor)
   x = layers.BatchNormalization(axis=bn_axis,
                                 momentum=BATCH_NORM_DECAY,
@@ -76,10 +75,9 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
   x = layers.Activation('relu')(x)

   x = layers.Conv2D(filters2, kernel_size,
-                    padding='same',
+                    padding='same', use_bias=False,
                     kernel_initializer='he_normal',
                     kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                     name=conv_name_base + '2b')(x)
   x = layers.BatchNormalization(axis=bn_axis,
                                 momentum=BATCH_NORM_DECAY,
@@ -87,10 +85,9 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
                                 name=bn_name_base + '2b')(x)
   x = layers.Activation('relu')(x)

-  x = layers.Conv2D(filters3, (1, 1),
+  x = layers.Conv2D(filters3, (1, 1), use_bias=False,
                     kernel_initializer='he_normal',
                     kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                     name=conv_name_base + '2c')(x)
   x = layers.BatchNormalization(axis=bn_axis,
                                 momentum=BATCH_NORM_DECAY,
@@ -134,9 +131,9 @@ def conv_block(input_tensor,
   conv_name_base = 'res' + str(stage) + block + '_branch'
   bn_name_base = 'bn' + str(stage) + block + '_branch'

-  x = layers.Conv2D(filters1, (1, 1), kernel_initializer='he_normal',
+  x = layers.Conv2D(filters1, (1, 1), use_bias=False,
+                    kernel_initializer='he_normal',
                     kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                     name=conv_name_base + '2a')(input_tensor)
   x = layers.BatchNormalization(axis=bn_axis,
                                 momentum=BATCH_NORM_DECAY,
@@ -145,9 +142,8 @@ def conv_block(input_tensor,
   x = layers.Activation('relu')(x)

   x = layers.Conv2D(filters2, kernel_size, strides=strides, padding='same',
-                    kernel_initializer='he_normal',
+                    use_bias=False, kernel_initializer='he_normal',
                     kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                     name=conv_name_base + '2b')(x)
   x = layers.BatchNormalization(axis=bn_axis,
                                 momentum=BATCH_NORM_DECAY,
@@ -155,20 +151,18 @@ def conv_block(input_tensor,
                                 name=bn_name_base + '2b')(x)
   x = layers.Activation('relu')(x)

-  x = layers.Conv2D(filters3, (1, 1),
+  x = layers.Conv2D(filters3, (1, 1), use_bias=False,
                     kernel_initializer='he_normal',
                     kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                     name=conv_name_base + '2c')(x)
   x = layers.BatchNormalization(axis=bn_axis,
                                 momentum=BATCH_NORM_DECAY,
                                 epsilon=BATCH_NORM_EPSILON,
                                 name=bn_name_base + '2c')(x)

-  shortcut = layers.Conv2D(filters3, (1, 1), strides=strides,
+  shortcut = layers.Conv2D(filters3, (1, 1), strides=strides, use_bias=False,
                            kernel_initializer='he_normal',
                            kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                           bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                            name=conv_name_base + '1')(input_tensor)
   shortcut = layers.BatchNormalization(axis=bn_axis,
                                        momentum=BATCH_NORM_DECAY,
@@ -204,10 +198,9 @@ def resnet50(num_classes):
   x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
   x = layers.Conv2D(64, (7, 7),
                     strides=(2, 2),
-                    padding='valid',
+                    padding='valid', use_bias=False,
                     kernel_initializer='he_normal',
                     kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                     name='conv1')(x)
   x = layers.BatchNormalization(axis=bn_axis,
                                 momentum=BATCH_NORM_DECAY,
...
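
The dropped bias_regularizer arguments really are no-ops once use_bias=False: a Conv2D built without a bias creates no bias variable, so there is nothing left to regularize. A quick illustrative check (not part of the commit; the shapes are arbitrary):

import tensorflow as tf

layer = tf.keras.layers.Conv2D(
    8, (1, 1), use_bias=False,
    kernel_regularizer=tf.keras.regularizers.l2(1e-4))
layer.build((None, 56, 56, 64))
print([w.name for w in layer.weights])  # kernel only, no bias variable
print(len(layer.losses))                # 1: just the kernel L2 penalty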