"...pytorch/mvgrl/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "36c6c64957af1fbe7f38688557f6ba14f51d1e2d"
Unverified Commit 178480ed authored by Neal Wu's avatar Neal Wu Committed by GitHub
Browse files

Remove batch norm weight decay + a few other fixes (#2755)

parent 4c37264d
...@@ -184,10 +184,11 @@ def resnet_model_fn(features, labels, mode, params): ...@@ -184,10 +184,11 @@ def resnet_model_fn(features, labels, mode, params):
tf.identity(cross_entropy, name='cross_entropy') tf.identity(cross_entropy, name='cross_entropy')
tf.summary.scalar('cross_entropy', cross_entropy) tf.summary.scalar('cross_entropy', cross_entropy)
# Add weight decay to the loss. We perform weight decay on all trainable # Add weight decay to the loss. We exclude the batch norm variables because
# variables, which includes batch norm beta and gamma variables. # doing so leads to a small improvement in accuracy.
loss = cross_entropy + _WEIGHT_DECAY * tf.add_n( loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
[tf.nn.l2_loss(v) for v in tf.trainable_variables()]) [tf.nn.l2_loss(v) for v in tf.trainable_variables()
if 'batch_normalization' not in v.name])
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf.estimator.ModeKeys.TRAIN:
# Scale the learning rate linearly with the batch size. When the batch size # Scale the learning rate linearly with the batch size. When the batch size
......
...@@ -242,8 +242,8 @@ def cifar10_resnet_v2_generator(resnet_size, num_classes, data_format=None): ...@@ -242,8 +242,8 @@ def cifar10_resnet_v2_generator(resnet_size, num_classes, data_format=None):
def model(inputs, is_training): def model(inputs, is_training):
"""Constructs the ResNet model given the inputs.""" """Constructs the ResNet model given the inputs."""
if data_format == 'channels_first': if data_format == 'channels_first':
# Convert from channels_last (NHWC) to channels_first (NCHW). This # Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
# provides a large performance boost on GPU. See # This provides a large performance boost on GPU. See
# https://www.tensorflow.org/performance/performance_guide#data_formats # https://www.tensorflow.org/performance/performance_guide#data_formats
inputs = tf.transpose(inputs, [0, 3, 1, 2]) inputs = tf.transpose(inputs, [0, 3, 1, 2])
...@@ -302,8 +302,9 @@ def imagenet_resnet_v2_generator(block_fn, layers, num_classes, ...@@ -302,8 +302,9 @@ def imagenet_resnet_v2_generator(block_fn, layers, num_classes,
def model(inputs, is_training): def model(inputs, is_training):
"""Constructs the ResNet model given the inputs.""" """Constructs the ResNet model given the inputs."""
if data_format == 'channels_first': if data_format == 'channels_first':
# Convert from channels_last (NHWC) to channels_first (NCHW). This # Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
# provides a large performance boost on GPU. # This provides a large performance boost on GPU. See
# https://www.tensorflow.org/performance/performance_guide#data_formats
inputs = tf.transpose(inputs, [0, 3, 1, 2]) inputs = tf.transpose(inputs, [0, 3, 1, 2])
inputs = conv2d_fixed_padding( inputs = conv2d_fixed_padding(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment