Commit a182abc1 authored by Haoyu Zhang's avatar Haoyu Zhang Committed by Toby Boyd
Browse files

Fix ResNet model convergence problem (#6721)

parent 58deb059
...@@ -368,7 +368,7 @@ def set_cudnn_batchnorm_mode(): ...@@ -368,7 +368,7 @@ def set_cudnn_batchnorm_mode():
if FLAGS.batchnorm_spatial_persistent: if FLAGS.batchnorm_spatial_persistent:
os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1' os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
else: else:
del os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] os.environ.pop('TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT', None)
def _monkey_patch_org_assert_broadcastable(): def _monkey_patch_org_assert_broadcastable():
......
...@@ -69,7 +69,6 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): ...@@ -69,7 +69,6 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor) name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(x) name=bn_name_base + '2a')(x)
...@@ -81,7 +80,6 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): ...@@ -81,7 +80,6 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x) name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(x) name=bn_name_base + '2b')(x)
...@@ -92,7 +90,6 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): ...@@ -92,7 +90,6 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2c')(x) name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2c')(x) name=bn_name_base + '2c')(x)
...@@ -139,7 +136,6 @@ def conv_block(input_tensor, ...@@ -139,7 +136,6 @@ def conv_block(input_tensor,
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor) name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(x) name=bn_name_base + '2a')(x)
...@@ -150,7 +146,6 @@ def conv_block(input_tensor, ...@@ -150,7 +146,6 @@ def conv_block(input_tensor,
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x) name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(x) name=bn_name_base + '2b')(x)
...@@ -161,7 +156,6 @@ def conv_block(input_tensor, ...@@ -161,7 +156,6 @@ def conv_block(input_tensor,
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2c')(x) name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2c')(x) name=bn_name_base + '2c')(x)
...@@ -171,7 +165,6 @@ def conv_block(input_tensor, ...@@ -171,7 +165,6 @@ def conv_block(input_tensor,
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '1')(input_tensor) name=conv_name_base + '1')(input_tensor)
shortcut = layers.BatchNormalization(axis=bn_axis, shortcut = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '1')(shortcut) name=bn_name_base + '1')(shortcut)
...@@ -211,13 +204,11 @@ def resnet50(num_classes, dtype='float32', batch_size=None): ...@@ -211,13 +204,11 @@ def resnet50(num_classes, dtype='float32', batch_size=None):
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='conv1')(x) name='conv1')(x)
x = layers.BatchNormalization(axis=bn_axis, x = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON, epsilon=BATCH_NORM_EPSILON,
name='bn_conv1')(x) name='bn_conv1')(x)
x = layers.Activation('relu')(x) x = layers.Activation('relu')(x)
x = layers.ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x) x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment