"libai/models/utils/vscode:/vscode.git/clone" did not exist on "3b355d3f9d5f4f5501ff6e76ba4018d83b640087"
Unverified Commit 58deb059 authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Enable CuDNN BatchNorm spatial persistent by default (#6710)

* Enable CuDNN BatchNorm spatial persistent by default; Remove 2nd zero padding layer

* Apply scale=False and fused=True consistently to BatchNorm layers

* Undo remove padding layer

* Replace zero padding with padding attribute in max pooling for better performance

* Resolve comments

* Revert "Replace zero padding with padding attribute in max pooling for better performance"

This reverts commit ad49db057c800ecac008eec1057005bd2c08ac73.
parent 0a96c7b4
......@@ -291,9 +291,13 @@ def define_keras_flags():
'help improve performance using EagerIterator and function. The codepath '
'when enabling this feature is experimental and will be removed once the '
'corresponding performance features are fully supported in TensorFlow.')
flags.DEFINE_boolean(name='clone_model_in_keras_dist_strat', default=True,
help='If False, then the experimental code path is used'
' that doesn\'t clone models for distribution.')
flags.DEFINE_boolean(
name='batchnorm_spatial_persistent', default=True,
help='Enable the spacial persistent mode for CuDNN batch norm kernel.')
flags.DEFINE_boolean(
name='clone_model_in_keras_dist_strat', default=True,
help='If False, then the experimental code path is used that doesn\'t '
'clone models for distribution.')
def get_synth_input_fn(height, width, num_channels, num_classes,
......@@ -358,6 +362,15 @@ def data_prefetch_with_slack():
_monkey_patch_org_create_device_dataset()
def set_cudnn_batchnorm_mode():
  """Set CuDNN batchnorm mode for better performance.

  When `FLAGS.batchnorm_spatial_persistent` is True, exports the
  TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT environment variable so TensorFlow
  selects the faster CuDNN spatial-persistent batch-norm kernel; otherwise the
  variable is removed from the environment. Note that the spatial persistent
  mode may lead to accuracy losses for certain models.
  """
  if FLAGS.batchnorm_spatial_persistent:
    os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
  else:
    # pop() with a default avoids a KeyError when the variable was never set
    # (the original `del os.environ[...]` crashed in that case).
    os.environ.pop('TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT', None)
def _monkey_patch_org_assert_broadcastable():
"""Monkey-patch `assert_broadcast` op to avoid OOM when enabling XLA."""
def no_op_assert_broadcastable(weights, values):
......
......@@ -109,6 +109,7 @@ def run(flags_obj):
keras_common.set_gpu_thread_mode_and_count(flags_obj)
if flags_obj.data_prefetch_with_slack:
keras_common.data_prefetch_with_slack()
keras_common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'float16':
......
......@@ -69,6 +69,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(x)
......@@ -80,6 +81,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(x)
......@@ -90,6 +92,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2c')(x)
......@@ -136,6 +139,7 @@ def conv_block(input_tensor,
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(x)
......@@ -146,6 +150,7 @@ def conv_block(input_tensor,
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(x)
......@@ -156,6 +161,7 @@ def conv_block(input_tensor,
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2c')(x)
......@@ -165,6 +171,7 @@ def conv_block(input_tensor,
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '1')(input_tensor)
shortcut = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '1')(shortcut)
......@@ -204,6 +211,7 @@ def resnet50(num_classes, dtype='float32', batch_size=None):
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='conv1')(x)
x = layers.BatchNormalization(axis=bn_axis,
scale=False,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name='bn_conv1')(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment