Unverified Commit 58deb059 authored by Haoyu Zhang, committed by GitHub

Enable CuDNN BatchNorm spatial persistent by default (#6710)

* Enable CuDNN BatchNorm spatial persistent by default; Remove 2nd zero padding layer

* Apply scale=False and fused=True consistently to BatchNorm layers

* Undo remove padding layer

* Replace zero padding with padding attribute in max pooling for better performance

* Resolve comments

* Revert "Replace zero padding with padding attribute in max pooling for better performance"

This reverts commit ad49db057c800ecac008eec1057005bd2c08ac73.
parent 0a96c7b4
@@ -291,9 +291,13 @@ def define_keras_flags():
      'help improve performance using EagerIterator and function. The codepath '
      'when enabling this feature is experimental and will be removed once the '
      'corresponding performance features are fully supported in TensorFlow.')
-  flags.DEFINE_boolean(name='clone_model_in_keras_dist_strat', default=True,
-                       help='If False, then the experimental code path is used'
-                       ' that doesn\'t clone models for distribution.')
+  flags.DEFINE_boolean(
+      name='batchnorm_spatial_persistent', default=True,
+      help='Enable the spatial persistent mode for CuDNN batch norm kernel.')
+  flags.DEFINE_boolean(
+      name='clone_model_in_keras_dist_strat', default=True,
+      help='If False, then the experimental code path is used that doesn\'t '
+           'clone models for distribution.')


 def get_synth_input_fn(height, width, num_channels, num_classes,
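
For reference, the new flag behaves like any other absl boolean flag, so it can be disabled with --nobatchnorm_spatial_persistent (or --batchnorm_spatial_persistent=false). A minimal standalone sketch of the same definition (the module layout here is illustrative, not part of the change):

    from absl import app
    from absl import flags

    flags.DEFINE_boolean(
        name='batchnorm_spatial_persistent', default=True,
        help='Enable the spatial persistent mode for CuDNN batch norm kernel.')
    FLAGS = flags.FLAGS

    def main(_):
      # Prints True by default; False when run with
      # --nobatchnorm_spatial_persistent.
      print('spatial persistent:', FLAGS.batchnorm_spatial_persistent)

    if __name__ == '__main__':
      app.run(main)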
@@ -358,6 +362,15 @@ def data_prefetch_with_slack():
   _monkey_patch_org_create_device_dataset()
+
+
+def set_cudnn_batchnorm_mode():
+  """Set CuDNN batchnorm mode for better performance. Note that the spatial
+  persistent mode may lead to accuracy losses for certain models."""
+  if FLAGS.batchnorm_spatial_persistent:
+    os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
+  else:
+    os.environ.pop('TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT', None)


 def _monkey_patch_org_assert_broadcastable():
   """Monkey-patch `assert_broadcast` op to avoid OOM when enabling XLA."""
   def no_op_assert_broadcastable(weights, values):
...
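
Note that the disable branch uses os.environ.pop(..., None) rather than del, since del raises KeyError when the variable was never set. The helper's two branches can be exercised in isolation; a minimal sanity-check sketch, assuming a stand-in for the absl FLAGS object (the _FakeFlags stub below is fabricated for illustration):

    import os

    class _FakeFlags:
      """Stand-in for absl's FLAGS, for illustration only."""
      batchnorm_spatial_persistent = True

    FLAGS = _FakeFlags()

    def set_cudnn_batchnorm_mode():
      """Mirrors the helper above: toggles the cuDNN batchnorm env var."""
      if FLAGS.batchnorm_spatial_persistent:
        os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
      else:
        # pop() avoids a KeyError when the variable was never set.
        os.environ.pop('TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT', None)

    set_cudnn_batchnorm_mode()
    assert os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] == '1'

    FLAGS.batchnorm_spatial_persistent = False
    set_cudnn_batchnorm_mode()
    assert 'TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT' not in os.environ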
@@ -109,6 +109,7 @@ def run(flags_obj):
     keras_common.set_gpu_thread_mode_and_count(flags_obj)
   if flags_obj.data_prefetch_with_slack:
     keras_common.data_prefetch_with_slack()
+  keras_common.set_cudnn_batchnorm_mode()

   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'float16':
...
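
One ordering subtlety: the helper mutates os.environ, and TensorFlow presumably reads the variable when it first instantiates the fused batch norm kernel, which is why the call sits near the top of run(), before any model is built. A condensed, self-contained sketch of that ordering (function bodies are illustrative placeholders, not the repo's code):

    import os

    def set_cudnn_batchnorm_mode(spatial_persistent=True):
      # Illustrative copy of the helper, with the flag passed explicitly.
      if spatial_persistent:
        os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
      else:
        os.environ.pop('TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT', None)

    def run():
      # Environment-level knobs first, model construction and training
      # after, so the kernel selection sees the variable.
      set_cudnn_batchnorm_mode(spatial_persistent=True)
      # ... build model, compile, fit ...

    run()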
@@ -69,6 +69,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
                     kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                     name=conv_name_base + '2a')(input_tensor)
   x = layers.BatchNormalization(axis=bn_axis,
+                                scale=False,
                                 momentum=BATCH_NORM_DECAY,
                                 epsilon=BATCH_NORM_EPSILON,
                                 name=bn_name_base + '2a')(x)
@@ -80,6 +81,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
                     kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                     name=conv_name_base + '2b')(x)
   x = layers.BatchNormalization(axis=bn_axis,
+                                scale=False,
                                 momentum=BATCH_NORM_DECAY,
                                 epsilon=BATCH_NORM_EPSILON,
                                 name=bn_name_base + '2b')(x)
@@ -90,6 +92,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
                     kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                     name=conv_name_base + '2c')(x)
   x = layers.BatchNormalization(axis=bn_axis,
+                                scale=False,
                                 momentum=BATCH_NORM_DECAY,
                                 epsilon=BATCH_NORM_EPSILON,
                                 name=bn_name_base + '2c')(x)
@@ -136,6 +139,7 @@ def conv_block(input_tensor,
                     kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                     name=conv_name_base + '2a')(input_tensor)
   x = layers.BatchNormalization(axis=bn_axis,
+                                scale=False,
                                 momentum=BATCH_NORM_DECAY,
                                 epsilon=BATCH_NORM_EPSILON,
                                 name=bn_name_base + '2a')(x)
@@ -146,6 +150,7 @@ def conv_block(input_tensor,
                     kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                     name=conv_name_base + '2b')(x)
   x = layers.BatchNormalization(axis=bn_axis,
+                                scale=False,
                                 momentum=BATCH_NORM_DECAY,
                                 epsilon=BATCH_NORM_EPSILON,
                                 name=bn_name_base + '2b')(x)
@@ -156,6 +161,7 @@ def conv_block(input_tensor,
                     kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                     name=conv_name_base + '2c')(x)
   x = layers.BatchNormalization(axis=bn_axis,
+                                scale=False,
                                 momentum=BATCH_NORM_DECAY,
                                 epsilon=BATCH_NORM_EPSILON,
                                 name=bn_name_base + '2c')(x)
@@ -165,6 +171,7 @@ def conv_block(input_tensor,
                     kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                     name=conv_name_base + '1')(input_tensor)
   shortcut = layers.BatchNormalization(axis=bn_axis,
+                                       scale=False,
                                        momentum=BATCH_NORM_DECAY,
                                        epsilon=BATCH_NORM_EPSILON,
                                        name=bn_name_base + '1')(shortcut)
@@ -204,6 +211,7 @@ def resnet50(num_classes, dtype='float32', batch_size=None):
                     kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                     name='conv1')(x)
   x = layers.BatchNormalization(axis=bn_axis,
+                                scale=False,
                                 momentum=BATCH_NORM_DECAY,
                                 epsilon=BATCH_NORM_EPSILON,
                                 name='bn_conv1')(x)
...
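
A note on scale=False: per the Keras documentation for BatchNormalization, the learned gamma scale can be disabled when the following layer is linear (or a ReLU), because the next layer's weights can absorb the scaling; dropping gamma removes a per-channel parameter from every block. A minimal sketch of the pattern used throughout this model (shapes and constants are illustrative, not taken from the repo):

    import tensorflow as tf

    BATCH_NORM_DECAY = 0.9    # illustrative values in the spirit of the model file
    BATCH_NORM_EPSILON = 1e-5

    inputs = tf.keras.Input(shape=(56, 56, 64))
    x = tf.keras.layers.Conv2D(64, 3, padding='same', use_bias=False)(inputs)
    # gamma omitted: the subsequent ReLU + Conv2D can absorb the scaling.
    x = tf.keras.layers.BatchNormalization(axis=3,
                                           scale=False,
                                           momentum=BATCH_NORM_DECAY,
                                           epsilon=BATCH_NORM_EPSILON)(x)
    x = tf.keras.layers.Activation('relu')(x)
    outputs = tf.keras.layers.Conv2D(64, 3, padding='same')(x)
    model = tf.keras.Model(inputs, outputs)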