"docs/vscode:/vscode.git/clone" did not exist on "c5939616ddc17093747e896db6012b1f63792627"
Unverified commit fe1a9089, authored by Haoyu Zhang, committed by GitHub
Browse files

Do not use bias in conv2d layers in Keras ResNet model (#6017)

* Set use_bias=False for conv layers in Keras ResNet model

* Remove the bias regularizer from Conv2D layers (the regularizer has no effect once the bias is removed)

* Set the default data format when none is given: 'channels_first' if TensorFlow is built with CUDA, otherwise 'channels_last'.
parent 1cdc35c8
...@@ -108,7 +108,11 @@ def run(flags_obj): ...@@ -108,7 +108,11 @@ def run(flags_obj):
per_device_batch_size = distribution_utils.per_device_batch_size( per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)) flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
tf.keras.backend.set_image_data_format(flags_obj.data_format) data_format = flags_obj.data_format
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
input_fn = keras_common.get_synth_input_fn( input_fn = keras_common.get_synth_input_fn(
......
...@@ -95,7 +95,11 @@ def run(flags_obj): ...@@ -95,7 +95,11 @@ def run(flags_obj):
raise ValueError('dtype fp16 is not supported in Keras. Use the default ' raise ValueError('dtype fp16 is not supported in Keras. Use the default '
'value(fp32).') 'value(fp32).')
tf.keras.backend.set_image_data_format(flags_obj.data_format) data_format = flags_obj.data_format
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
per_device_batch_size = distribution_utils.per_device_batch_size( per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)) flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
......
...@@ -64,10 +64,9 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): ...@@ -64,10 +64,9 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
conv_name_base = 'res' + str(stage) + block + '_branch' conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch' bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(filters1, (1, 1), x = layers.Conv2D(filters1, (1, 1), use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor) name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
...@@ -76,10 +75,9 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): ...@@ -76,10 +75,9 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
x = layers.Activation('relu')(x) x = layers.Activation('relu')(x)
x = layers.Conv2D(filters2, kernel_size, x = layers.Conv2D(filters2, kernel_size,
padding='same', padding='same', use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x) name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
...@@ -87,10 +85,9 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): ...@@ -87,10 +85,9 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
name=bn_name_base + '2b')(x) name=bn_name_base + '2b')(x)
x = layers.Activation('relu')(x) x = layers.Activation('relu')(x)
x = layers.Conv2D(filters3, (1, 1), x = layers.Conv2D(filters3, (1, 1), use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2c')(x) name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
...@@ -134,9 +131,9 @@ def conv_block(input_tensor, ...@@ -134,9 +131,9 @@ def conv_block(input_tensor,
conv_name_base = 'res' + str(stage) + block + '_branch' conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch' bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(filters1, (1, 1), kernel_initializer='he_normal', x = layers.Conv2D(filters1, (1, 1), use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor) name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
...@@ -145,9 +142,8 @@ def conv_block(input_tensor, ...@@ -145,9 +142,8 @@ def conv_block(input_tensor,
x = layers.Activation('relu')(x) x = layers.Activation('relu')(x)
x = layers.Conv2D(filters2, kernel_size, strides=strides, padding='same', x = layers.Conv2D(filters2, kernel_size, strides=strides, padding='same',
kernel_initializer='he_normal', use_bias=False, kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x) name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
...@@ -155,20 +151,18 @@ def conv_block(input_tensor, ...@@ -155,20 +151,18 @@ def conv_block(input_tensor,
name=bn_name_base + '2b')(x) name=bn_name_base + '2b')(x)
x = layers.Activation('relu')(x) x = layers.Activation('relu')(x)
x = layers.Conv2D(filters3, (1, 1), x = layers.Conv2D(filters3, (1, 1), use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2c')(x) name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2c')(x) name=bn_name_base + '2c')(x)
shortcut = layers.Conv2D(filters3, (1, 1), strides=strides, shortcut = layers.Conv2D(filters3, (1, 1), strides=strides, use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '1')(input_tensor) name=conv_name_base + '1')(input_tensor)
shortcut = layers.BatchNormalization(axis=bn_axis, shortcut = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
...@@ -204,10 +198,9 @@ def resnet50(num_classes): ...@@ -204,10 +198,9 @@ def resnet50(num_classes):
x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x) x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
x = layers.Conv2D(64, (7, 7), x = layers.Conv2D(64, (7, 7),
strides=(2, 2), strides=(2, 2),
padding='valid', padding='valid', use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='conv1')(x) name='conv1')(x)
x = layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment