Commit ad730d97 authored by Derek Chow's avatar Derek Chow
Browse files

Update slim/nets.

- Set spatial squeeze to true for resnets
- Make activations configurable for resnets
- Make batch norm configurable for resnets
- Fix some bad comment wrapping in vgg.py.
parent beeae099
......@@ -199,7 +199,9 @@ def stack_blocks_dense(net, blocks, output_stride=None,
def resnet_arg_scope(weight_decay=0.0001,
batch_norm_decay=0.997,
batch_norm_epsilon=1e-5,
batch_norm_scale=True):
batch_norm_scale=True,
activation_fn=tf.nn.relu,
use_batch_norm=True):
"""Defines the default ResNet arg scope.
TODO(gpapan): The batch-normalization related default values above are
......@@ -215,6 +217,8 @@ def resnet_arg_scope(weight_decay=0.0001,
normalizing activations by their variance in batch normalization.
batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
activations in the batch normalization layer.
activation_fn: The activation function which is used in ResNet.
use_batch_norm: Whether or not to use batch normalization.
Returns:
An `arg_scope` to use for the resnet models.
......@@ -230,8 +234,8 @@ def resnet_arg_scope(weight_decay=0.0001,
[slim.conv2d],
weights_regularizer=slim.l2_regularizer(weight_decay),
weights_initializer=slim.variance_scaling_initializer(),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm,
activation_fn=activation_fn,
normalizer_fn=slim.batch_norm if use_batch_norm else None,
normalizer_params=batch_norm_params):
with slim.arg_scope([slim.batch_norm], **batch_norm_params):
# The following implies padding='SAME' for pool1, which makes feature
......
......@@ -66,8 +66,14 @@ slim = tf.contrib.slim
@slim.add_arg_scope
def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
outputs_collections=None, scope=None):
def bottleneck(inputs,
depth,
depth_bottleneck,
stride,
rate=1,
outputs_collections=None,
scope=None,
use_bounded_activations=False):
"""Bottleneck residual unit variant with BN after convolutions.
This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for
......@@ -86,6 +92,8 @@ def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
rate: An integer, rate for atrous convolution.
outputs_collections: Collection to add the ResNet unit output.
scope: Optional variable_scope.
use_bounded_activations: Whether or not to use bounded activations. Bounded
activations better lend themselves to quantized inference.
Returns:
The ResNet unit's output.
......@@ -95,8 +103,12 @@ def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
if depth == depth_in:
shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
else:
shortcut = slim.conv2d(inputs, depth, [1, 1], stride=stride,
activation_fn=None, scope='shortcut')
shortcut = slim.conv2d(
inputs,
depth, [1, 1],
stride=stride,
activation_fn=tf.nn.relu6 if use_bounded_activations else None,
scope='shortcut')
residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1,
scope='conv1')
......@@ -105,6 +117,11 @@ def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
residual = slim.conv2d(residual, depth, [1, 1], stride=1,
activation_fn=None, scope='conv3')
if use_bounded_activations:
# Use clip_by_value to simulate bandpass activation.
residual = tf.clip_by_value(residual, -6.0, 6.0)
output = tf.nn.relu6(shortcut + residual)
else:
output = tf.nn.relu(shortcut + residual)
return slim.utils.collect_named_outputs(outputs_collections,
......@@ -119,7 +136,7 @@ def resnet_v1(inputs,
global_pool=True,
output_stride=None,
include_root_block=True,
spatial_squeeze=False,
spatial_squeeze=True,
reuse=None,
scope=None):
"""Generator for v1 ResNet models.
......
......@@ -251,6 +251,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
global_pool=True,
output_stride=None,
include_root_block=True,
spatial_squeeze=True,
reuse=None,
scope='resnet_v1_small'):
"""A shallow and thin ResNet v1 for faster tests."""
......@@ -266,6 +267,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
global_pool=global_pool,
output_stride=output_stride,
include_root_block=include_root_block,
spatial_squeeze=spatial_squeeze,
reuse=reuse,
scope=scope)
......@@ -276,6 +278,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
logits, end_points = self._resnet_small(inputs, num_classes,
global_pool=global_pool,
spatial_squeeze=False,
scope='resnet')
self.assertTrue(logits.op.name.startswith('resnet/logits'))
self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
......@@ -307,6 +310,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
_, end_points = self._resnet_small(inputs, num_classes,
global_pool=global_pool,
spatial_squeeze=False,
scope='resnet')
endpoint_to_shape = {
'resnet/block1': [2, 41, 41, 4],
......@@ -325,6 +329,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
_, end_points = self._resnet_small(inputs, num_classes,
global_pool=global_pool,
include_root_block=False,
spatial_squeeze=False,
scope='resnet')
endpoint_to_shape = {
'resnet/block1': [2, 64, 64, 4],
......@@ -345,6 +350,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
num_classes,
global_pool=global_pool,
output_stride=output_stride,
spatial_squeeze=False,
scope='resnet')
endpoint_to_shape = {
'resnet/block1': [2, 41, 41, 4],
......@@ -391,6 +397,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
logits, _ = self._resnet_small(inputs, num_classes,
global_pool=global_pool,
spatial_squeeze=False,
scope='resnet')
self.assertTrue(logits.op.name.startswith('resnet/logits'))
self.assertListEqual(logits.get_shape().as_list(),
......
......@@ -115,7 +115,7 @@ def resnet_v2(inputs,
global_pool=True,
output_stride=None,
include_root_block=True,
spatial_squeeze=False,
spatial_squeeze=True,
reuse=None,
scope=None):
"""Generator for v2 (preactivation) ResNet models.
......@@ -251,7 +251,7 @@ def resnet_v2_50(inputs,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=False,
spatial_squeeze=True,
reuse=None,
scope='resnet_v2_50'):
"""ResNet-50 model of [1]. See resnet_v2() for arg and return description."""
......@@ -273,7 +273,7 @@ def resnet_v2_101(inputs,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=False,
spatial_squeeze=True,
reuse=None,
scope='resnet_v2_101'):
"""ResNet-101 model of [1]. See resnet_v2() for arg and return description."""
......@@ -295,7 +295,7 @@ def resnet_v2_152(inputs,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=False,
spatial_squeeze=True,
reuse=None,
scope='resnet_v2_152'):
"""ResNet-152 model of [1]. See resnet_v2() for arg and return description."""
......@@ -317,7 +317,7 @@ def resnet_v2_200(inputs,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=False,
spatial_squeeze=True,
reuse=None,
scope='resnet_v2_200'):
"""ResNet-200 model of [2]. See resnet_v2() for arg and return description."""
......
......@@ -251,6 +251,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
global_pool=True,
output_stride=None,
include_root_block=True,
spatial_squeeze=True,
reuse=None,
scope='resnet_v2_small'):
"""A shallow and thin ResNet v2 for faster tests."""
......@@ -266,6 +267,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
global_pool=global_pool,
output_stride=output_stride,
include_root_block=include_root_block,
spatial_squeeze=spatial_squeeze,
reuse=reuse,
scope=scope)
......@@ -276,6 +278,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
logits, end_points = self._resnet_small(inputs, num_classes,
global_pool=global_pool,
spatial_squeeze=False,
scope='resnet')
self.assertTrue(logits.op.name.startswith('resnet/logits'))
self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
......@@ -307,6 +310,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
_, end_points = self._resnet_small(inputs, num_classes,
global_pool=global_pool,
spatial_squeeze=False,
scope='resnet')
endpoint_to_shape = {
'resnet/block1': [2, 41, 41, 4],
......@@ -325,6 +329,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
_, end_points = self._resnet_small(inputs, num_classes,
global_pool=global_pool,
include_root_block=False,
spatial_squeeze=False,
scope='resnet')
endpoint_to_shape = {
'resnet/block1': [2, 64, 64, 4],
......@@ -345,6 +350,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
num_classes,
global_pool=global_pool,
output_stride=output_stride,
spatial_squeeze=False,
scope='resnet')
endpoint_to_shape = {
'resnet/block1': [2, 41, 41, 4],
......@@ -393,6 +399,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
logits, _ = self._resnet_small(inputs, num_classes,
global_pool=global_pool,
spatial_squeeze=False,
scope='resnet')
self.assertTrue(logits.op.name.startswith('resnet/logits'))
self.assertListEqual(logits.get_shape().as_list(),
......
......@@ -87,8 +87,9 @@ def vgg_a(inputs,
fc_conv_padding: the type of padding to use for the fully connected layer
that is implemented as a convolutional layer. Use 'SAME' padding if you
are applying the network in a fully convolutional manner and want to
get a prediction map downsampled by a factor of 32 as an output. Otherwise,
the output prediction map will be (input / 32) - 6 in case of 'VALID' padding.
get a prediction map downsampled by a factor of 32 as an output.
Otherwise, the output prediction map will be (input / 32) - 6 in case of
'VALID' padding.
Returns:
the last op containing the log predictions and end_points dict.
......@@ -152,8 +153,9 @@ def vgg_16(inputs,
fc_conv_padding: the type of padding to use for the fully connected layer
that is implemented as a convolutional layer. Use 'SAME' padding if you
are applying the network in a fully convolutional manner and want to
get a prediction map downsampled by a factor of 32 as an output. Otherwise,
the output prediction map will be (input / 32) - 6 in case of 'VALID' padding.
get a prediction map downsampled by a factor of 32 as an output.
Otherwise, the output prediction map will be (input / 32) - 6 in case of
'VALID' padding.
Returns:
the last op containing the log predictions and end_points dict.
......@@ -217,8 +219,10 @@ def vgg_19(inputs,
fc_conv_padding: the type of padding to use for the fully connected layer
that is implemented as a convolutional layer. Use 'SAME' padding if you
are applying the network in a fully convolutional manner and want to
get a prediction map downsampled by a factor of 32 as an output. Otherwise,
the output prediction map will be (input / 32) - 6 in case of 'VALID' padding.
get a prediction map downsampled by a factor of 32 as an output.
Otherwise, the output prediction map will be (input / 32) - 6 in case of
'VALID' padding.
Returns:
the last op containing the log predictions and end_points dict.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment