Unverified commit 178480ed, authored by Neal Wu, committed by GitHub
Browse files

Remove batch norm weight decay + a few other fixes (#2755)

parent 4c37264d
......@@ -184,10 +184,11 @@ def resnet_model_fn(features, labels, mode, params):
tf.identity(cross_entropy, name='cross_entropy')
tf.summary.scalar('cross_entropy', cross_entropy)
# Add weight decay to the loss. We perform weight decay on all trainable
# variables, which includes batch norm beta and gamma variables.
# Add weight decay to the loss. We exclude the batch norm variables because
# doing so leads to a small improvement in accuracy.
loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
[tf.nn.l2_loss(v) for v in tf.trainable_variables()])
[tf.nn.l2_loss(v) for v in tf.trainable_variables()
if 'batch_normalization' not in v.name])
if mode == tf.estimator.ModeKeys.TRAIN:
# Scale the learning rate linearly with the batch size. When the batch size
......
......@@ -242,8 +242,8 @@ def cifar10_resnet_v2_generator(resnet_size, num_classes, data_format=None):
def model(inputs, is_training):
"""Constructs the ResNet model given the inputs."""
if data_format == 'channels_first':
# Convert from channels_last (NHWC) to channels_first (NCHW). This
# provides a large performance boost on GPU. See
# Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
# This provides a large performance boost on GPU. See
# https://www.tensorflow.org/performance/performance_guide#data_formats
inputs = tf.transpose(inputs, [0, 3, 1, 2])
......@@ -302,8 +302,9 @@ def imagenet_resnet_v2_generator(block_fn, layers, num_classes,
def model(inputs, is_training):
"""Constructs the ResNet model given the inputs."""
if data_format == 'channels_first':
# Convert from channels_last (NHWC) to channels_first (NCHW). This
# provides a large performance boost on GPU.
# Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
# This provides a large performance boost on GPU. See
# https://www.tensorflow.org/performance/performance_guide#data_formats
inputs = tf.transpose(inputs, [0, 3, 1, 2])
inputs = conv2d_fixed_padding(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment