Commit ad730d97 authored by Derek Chow

Update slim/nets.

- Set spatial_squeeze to True for ResNets.
- Make activations configurable for ResNets.
- Make batch norm configurable for ResNets.
- Fix some bad comment wrapping in vgg.py.
parent beeae099
@@ -199,7 +199,9 @@ def stack_blocks_dense(net, blocks, output_stride=None,
 def resnet_arg_scope(weight_decay=0.0001,
                      batch_norm_decay=0.997,
                      batch_norm_epsilon=1e-5,
-                     batch_norm_scale=True):
+                     batch_norm_scale=True,
+                     activation_fn=tf.nn.relu,
+                     use_batch_norm=True):
   """Defines the default ResNet arg scope.

   TODO(gpapan): The batch-normalization related default values above are
@@ -215,6 +217,8 @@ def resnet_arg_scope(weight_decay=0.0001,
       normalizing activations by their variance in batch normalization.
     batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
       activations in the batch normalization layer.
+    activation_fn: The activation function which is used in ResNet.
+    use_batch_norm: Whether or not to use batch normalization.

   Returns:
     An `arg_scope` to use for the resnet models.
@@ -230,8 +234,8 @@ def resnet_arg_scope(weight_decay=0.0001,
       [slim.conv2d],
       weights_regularizer=slim.l2_regularizer(weight_decay),
       weights_initializer=slim.variance_scaling_initializer(),
-      activation_fn=tf.nn.relu,
-      normalizer_fn=slim.batch_norm,
+      activation_fn=activation_fn,
+      normalizer_fn=slim.batch_norm if use_batch_norm else None,
       normalizer_params=batch_norm_params):
     with slim.arg_scope([slim.batch_norm], **batch_norm_params):
       # The following implies padding='SAME' for pool1, which makes feature
......
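With `activation_fn` and `use_batch_norm` exposed, callers can build the arg scope for non-default setups instead of always getting relu plus batch norm. A minimal usage sketch, assuming the standard `tensorflow`/`slim` imports already present in this file; the relu6 choice below is illustrative, not mandated by the diff:

import tensorflow as tf

slim = tf.contrib.slim

# Hypothetical caller: a quantization-friendly scope with a bounded
# activation and batch norm disabled.
with slim.arg_scope(resnet_arg_scope(activation_fn=tf.nn.relu6,
                                     use_batch_norm=False)):
  # Inside this scope, slim.conv2d defaults to activation_fn=tf.nn.relu6
  # and normalizer_fn=None, so no batch-norm variables are created.
  net = slim.conv2d(tf.zeros([1, 224, 224, 3]), 64, [7, 7], stride=2)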
@@ -66,8 +66,14 @@ slim = tf.contrib.slim
 @slim.add_arg_scope
-def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
-               outputs_collections=None, scope=None):
+def bottleneck(inputs,
+               depth,
+               depth_bottleneck,
+               stride,
+               rate=1,
+               outputs_collections=None,
+               scope=None,
+               use_bounded_activations=False):
   """Bottleneck residual unit variant with BN after convolutions.

   This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for
@@ -86,6 +92,8 @@ def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
     rate: An integer, rate for atrous convolution.
     outputs_collections: Collection to add the ResNet unit output.
     scope: Optional variable_scope.
+    use_bounded_activations: Whether or not to use bounded activations. Bounded
+      activations better lend themselves to quantized inference.

   Returns:
     The ResNet unit's output.
@@ -95,8 +103,12 @@ def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
     if depth == depth_in:
       shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
     else:
-      shortcut = slim.conv2d(inputs, depth, [1, 1], stride=stride,
-                             activation_fn=None, scope='shortcut')
+      shortcut = slim.conv2d(
+          inputs,
+          depth, [1, 1],
+          stride=stride,
+          activation_fn=tf.nn.relu6 if use_bounded_activations else None,
+          scope='shortcut')

     residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1,
                            scope='conv1')
@@ -105,7 +117,12 @@ def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
     residual = slim.conv2d(residual, depth, [1, 1], stride=1,
                            activation_fn=None, scope='conv3')

-    output = tf.nn.relu(shortcut + residual)
+    if use_bounded_activations:
+      # Use clip_by_value to simulate bandpass activation.
+      residual = tf.clip_by_value(residual, -6.0, 6.0)
+      output = tf.nn.relu6(shortcut + residual)
+    else:
+      output = tf.nn.relu(shortcut + residual)

     return slim.utils.collect_named_outputs(outputs_collections,
                                             sc.original_name_scope,
@@ -119,7 +136,7 @@ def resnet_v1(inputs,
               global_pool=True,
               output_stride=None,
               include_root_block=True,
-              spatial_squeeze=False,
+              spatial_squeeze=True,
               reuse=None,
               scope=None):
   """Generator for v1 ResNet models.
......
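The bounded-activation path keeps every tensor in the unit inside a fixed range: the shortcut branch now ends in `tf.nn.relu6` (range [0, 6]), the residual branch is clipped to [-6, 6] before the add, and the sum passes through `relu6` again, so the output always lands in [0, 6]. A fixed range like that maps onto a single 8-bit quantization scale. A toy numeric sketch of that data path; the constants stand in for branch outputs and are purely illustrative:

import tensorflow as tf

shortcut = tf.constant([0.5, 6.0, 2.0])    # stand-in for the relu6 shortcut
residual = tf.constant([10.0, -8.0, 1.0])  # stand-in for the conv3 output

# use_bounded_activations=True path: clip, add, relu6.
# clip -> [6.0, -6.0, 1.0]; sum -> [6.5, 0.0, 3.0]; relu6 -> [6.0, 0.0, 3.0]
bounded = tf.nn.relu6(shortcut + tf.clip_by_value(residual, -6.0, 6.0))

# Default path: plain relu, so the output is unbounded above.
# sum -> [10.5, -2.0, 3.0]; relu -> [10.5, 0.0, 3.0]
unbounded = tf.nn.relu(shortcut + residual)

with tf.Session() as sess:
  print(sess.run([bounded, unbounded]))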
@@ -251,6 +251,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
                     global_pool=True,
                     output_stride=None,
                     include_root_block=True,
+                    spatial_squeeze=True,
                     reuse=None,
                     scope='resnet_v1_small'):
     """A shallow and thin ResNet v1 for faster tests."""
@@ -266,6 +267,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
                               global_pool=global_pool,
                               output_stride=output_stride,
                               include_root_block=include_root_block,
+                              spatial_squeeze=spatial_squeeze,
                               reuse=reuse,
                               scope=scope)
@@ -276,6 +278,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
     with slim.arg_scope(resnet_utils.resnet_arg_scope()):
       logits, end_points = self._resnet_small(inputs, num_classes,
                                               global_pool=global_pool,
+                                              spatial_squeeze=False,
                                               scope='resnet')
     self.assertTrue(logits.op.name.startswith('resnet/logits'))
     self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
@@ -307,6 +310,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
     with slim.arg_scope(resnet_utils.resnet_arg_scope()):
       _, end_points = self._resnet_small(inputs, num_classes,
                                          global_pool=global_pool,
+                                         spatial_squeeze=False,
                                          scope='resnet')
     endpoint_to_shape = {
         'resnet/block1': [2, 41, 41, 4],
@@ -325,6 +329,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
       _, end_points = self._resnet_small(inputs, num_classes,
                                          global_pool=global_pool,
                                          include_root_block=False,
+                                         spatial_squeeze=False,
                                          scope='resnet')
     endpoint_to_shape = {
         'resnet/block1': [2, 64, 64, 4],
@@ -345,6 +350,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
                                          num_classes,
                                          global_pool=global_pool,
                                          output_stride=output_stride,
+                                         spatial_squeeze=False,
                                          scope='resnet')
     endpoint_to_shape = {
         'resnet/block1': [2, 41, 41, 4],
@@ -391,6 +397,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
     with slim.arg_scope(resnet_utils.resnet_arg_scope()):
       logits, _ = self._resnet_small(inputs, num_classes,
                                      global_pool=global_pool,
+                                     spatial_squeeze=False,
                                      scope='resnet')
     self.assertTrue(logits.op.name.startswith('resnet/logits'))
     self.assertListEqual(logits.get_shape().as_list(),
......
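The test changes follow directly from the flipped default: `_resnet_small` now forwards a `spatial_squeeze` argument, and the shape-checking tests pass `spatial_squeeze=False` so their assertions on rank-4 logits such as `[2, 1, 1, num_classes]` still hold. In these models the squeeze is a `tf.squeeze` over the two unit spatial axes; a shape-only sketch:

import tensorflow as tf

num_classes = 10
logits = tf.zeros([2, 1, 1, num_classes])  # global-pooled logits, NHWC
squeezed = tf.squeeze(logits, [1, 2])      # the spatial_squeeze=True path
print(logits.get_shape().as_list())        # [2, 1, 1, 10]
print(squeezed.get_shape().as_list())      # [2, 10]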
@@ -115,7 +115,7 @@ def resnet_v2(inputs,
               global_pool=True,
               output_stride=None,
               include_root_block=True,
-              spatial_squeeze=False,
+              spatial_squeeze=True,
               reuse=None,
               scope=None):
   """Generator for v2 (preactivation) ResNet models.
@@ -251,7 +251,7 @@ def resnet_v2_50(inputs,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
-                 spatial_squeeze=False,
+                 spatial_squeeze=True,
                  reuse=None,
                  scope='resnet_v2_50'):
   """ResNet-50 model of [1]. See resnet_v2() for arg and return description."""
@@ -273,7 +273,7 @@ def resnet_v2_101(inputs,
                   is_training=True,
                   global_pool=True,
                   output_stride=None,
-                  spatial_squeeze=False,
+                  spatial_squeeze=True,
                   reuse=None,
                   scope='resnet_v2_101'):
   """ResNet-101 model of [1]. See resnet_v2() for arg and return description."""
@@ -295,7 +295,7 @@ def resnet_v2_152(inputs,
                   is_training=True,
                   global_pool=True,
                   output_stride=None,
-                  spatial_squeeze=False,
+                  spatial_squeeze=True,
                   reuse=None,
                   scope='resnet_v2_152'):
   """ResNet-152 model of [1]. See resnet_v2() for arg and return description."""
@@ -317,7 +317,7 @@ def resnet_v2_200(inputs,
                   is_training=True,
                   global_pool=True,
                   output_stride=None,
-                  spatial_squeeze=False,
+                  spatial_squeeze=True,
                   reuse=None,
                   scope='resnet_v2_200'):
   """ResNet-200 model of [2]. See resnet_v2() for arg and return description."""
......
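After the default change, the canned model functions return rank-2 logits unless the caller opts out. A hedged call sketch; the import path assumes the slim/nets layout, and the `resnet_arg_scope` re-export from resnet_utils is assumed rather than shown in this diff:

import tensorflow as tf
from nets import resnet_v2  # assumed slim/nets-style import

slim = tf.contrib.slim

inputs = tf.random_uniform([2, 224, 224, 3])
with slim.arg_scope(resnet_v2.resnet_arg_scope()):
  logits, end_points = resnet_v2.resnet_v2_50(inputs, num_classes=1000)
# spatial_squeeze now defaults to True, so logits has shape [2, 1000];
# pass spatial_squeeze=False to keep the old [2, 1, 1, 1000] output.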
@@ -251,6 +251,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
                     global_pool=True,
                     output_stride=None,
                     include_root_block=True,
+                    spatial_squeeze=True,
                     reuse=None,
                     scope='resnet_v2_small'):
     """A shallow and thin ResNet v2 for faster tests."""
@@ -266,6 +267,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
                               global_pool=global_pool,
                               output_stride=output_stride,
                               include_root_block=include_root_block,
+                              spatial_squeeze=spatial_squeeze,
                               reuse=reuse,
                               scope=scope)
@@ -276,6 +278,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
     with slim.arg_scope(resnet_utils.resnet_arg_scope()):
       logits, end_points = self._resnet_small(inputs, num_classes,
                                               global_pool=global_pool,
+                                              spatial_squeeze=False,
                                               scope='resnet')
     self.assertTrue(logits.op.name.startswith('resnet/logits'))
     self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
@@ -307,6 +310,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
     with slim.arg_scope(resnet_utils.resnet_arg_scope()):
       _, end_points = self._resnet_small(inputs, num_classes,
                                          global_pool=global_pool,
+                                         spatial_squeeze=False,
                                          scope='resnet')
     endpoint_to_shape = {
         'resnet/block1': [2, 41, 41, 4],
@@ -325,6 +329,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
       _, end_points = self._resnet_small(inputs, num_classes,
                                          global_pool=global_pool,
                                          include_root_block=False,
+                                         spatial_squeeze=False,
                                          scope='resnet')
     endpoint_to_shape = {
         'resnet/block1': [2, 64, 64, 4],
@@ -345,6 +350,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
                                          num_classes,
                                          global_pool=global_pool,
                                          output_stride=output_stride,
+                                         spatial_squeeze=False,
                                          scope='resnet')
     endpoint_to_shape = {
         'resnet/block1': [2, 41, 41, 4],
@@ -393,6 +399,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
     with slim.arg_scope(resnet_utils.resnet_arg_scope()):
       logits, _ = self._resnet_small(inputs, num_classes,
                                      global_pool=global_pool,
+                                     spatial_squeeze=False,
                                      scope='resnet')
     self.assertTrue(logits.op.name.startswith('resnet/logits'))
     self.assertListEqual(logits.get_shape().as_list(),
......
@@ -87,8 +87,9 @@ def vgg_a(inputs,
     fc_conv_padding: the type of padding to use for the fully connected layer
       that is implemented as a convolutional layer. Use 'SAME' padding if you
       are applying the network in a fully convolutional manner and want to
-      get a prediction map downsampled by a factor of 32 as an output. Otherwise,
-      the output prediction map will be (input / 32) - 6 in case of 'VALID' padding.
+      get a prediction map downsampled by a factor of 32 as an output.
+      Otherwise, the output prediction map will be (input / 32) - 6 in case of
+      'VALID' padding.

   Returns:
     the last op containing the log predictions and end_points dict.
@@ -152,8 +153,9 @@ def vgg_16(inputs,
     fc_conv_padding: the type of padding to use for the fully connected layer
       that is implemented as a convolutional layer. Use 'SAME' padding if you
       are applying the network in a fully convolutional manner and want to
-      get a prediction map downsampled by a factor of 32 as an output. Otherwise,
-      the output prediction map will be (input / 32) - 6 in case of 'VALID' padding.
+      get a prediction map downsampled by a factor of 32 as an output.
+      Otherwise, the output prediction map will be (input / 32) - 6 in case of
+      'VALID' padding.

   Returns:
     the last op containing the log predictions and end_points dict.
@@ -217,8 +219,10 @@ def vgg_19(inputs,
     fc_conv_padding: the type of padding to use for the fully connected layer
       that is implemented as a convolutional layer. Use 'SAME' padding if you
       are applying the network in a fully convolutional manner and want to
-      get a prediction map downsampled by a factor of 32 as an output. Otherwise,
-      the output prediction map will be (input / 32) - 6 in case of 'VALID' padding.
+      get a prediction map downsampled by a factor of 32 as an output.
+      Otherwise, the output prediction map will be (input / 32) - 6 in case of
+      'VALID' padding.

   Returns:
     the last op containing the log predictions and end_points dict.
......
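The rewrapped docstring carries an output-size rule worth a worked check: the five 2x2 max-pools downsample by 32, and under 'VALID' padding the 7x7 fc6 convolution then trims 7 - 1 = 6 from each spatial dimension. A quick arithmetic sketch; the helper name is hypothetical and the sizes illustrative:

def vgg_valid_map_size(input_size):
  # /32 from the five max-pools, minus 6 from the 7x7 'VALID' fc6 conv.
  return input_size // 32 - 6

print(vgg_valid_map_size(224))  # 1: the standard 224x224 crop yields a 1x1 map
print(vgg_valid_map_size(512))  # 10: fully convolutional inference gives 10x10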