Commit 60c3ed2e authored by derekjchow's avatar derekjchow Committed by Sergio Guadarrama
Browse files

Update resnet (#1559)

parent fc7342bf
...@@ -178,26 +178,16 @@ def stack_blocks_dense(net, blocks, output_stride=None, ...@@ -178,26 +178,16 @@ def stack_blocks_dense(net, blocks, output_stride=None,
raise ValueError('The target output_stride cannot be reached.') raise ValueError('The target output_stride cannot be reached.')
with tf.variable_scope('unit_%d' % (i + 1), values=[net]): with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
unit_depth, unit_depth_bottleneck, unit_stride = unit
# If we have reached the target output_stride, then we need to employ # If we have reached the target output_stride, then we need to employ
# atrous convolution with stride=1 and multiply the atrous rate by the # atrous convolution with stride=1 and multiply the atrous rate by the
# current unit's stride for use in subsequent layers. # current unit's stride for use in subsequent layers.
if output_stride is not None and current_stride == output_stride: if output_stride is not None and current_stride == output_stride:
net = block.unit_fn(net, net = block.unit_fn(net, rate=rate, **dict(unit, stride=1))
depth=unit_depth, rate *= unit.get('stride', 1)
depth_bottleneck=unit_depth_bottleneck,
stride=1,
rate=rate)
rate *= unit_stride
else: else:
net = block.unit_fn(net, net = block.unit_fn(net, rate=1, **unit)
depth=unit_depth, current_stride *= unit.get('stride', 1)
depth_bottleneck=unit_depth_bottleneck,
stride=unit_stride,
rate=1)
current_stride *= unit_stride
net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
if output_stride is not None and current_stride != output_stride: if output_stride is not None and current_stride != output_stride:
......
...@@ -119,7 +119,7 @@ def resnet_v1(inputs, ...@@ -119,7 +119,7 @@ def resnet_v1(inputs,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
include_root_block=True, include_root_block=True,
spatial_squeeze=True, spatial_squeeze=False,
reuse=None, reuse=None,
scope=None): scope=None):
"""Generator for v1 ResNet models. """Generator for v1 ResNet models.
...@@ -205,13 +205,38 @@ def resnet_v1(inputs, ...@@ -205,13 +205,38 @@ def resnet_v1(inputs,
else: else:
logits = net logits = net
# Convert end_points_collection into a dictionary of end_points. # Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict(end_points_collection) end_points = slim.utils.convert_collection_to_dict(
end_points_collection)
if num_classes is not None: if num_classes is not None:
end_points['predictions'] = slim.softmax(logits, scope='predictions') end_points['predictions'] = slim.softmax(logits, scope='predictions')
return logits, end_points return logits, end_points
resnet_v1.default_image_size = 224 resnet_v1.default_image_size = 224
def resnet_v1_block(scope, base_depth, num_units, stride):
"""Helper function for creating a resnet_v1 bottleneck block.
Args:
scope: The scope of the block.
base_depth: The depth of the bottleneck layer for each unit.
num_units: The number of units in the block.
stride: The stride of the block, implemented as a stride in the last unit.
All other units have stride=1.
Returns:
A resnet_v1 bottleneck block.
"""
return resnet_utils.Block(scope, bottleneck, [{
'depth': base_depth * 4,
'depth_bottleneck': base_depth,
'stride': 1
}] * (num_units - 1) + [{
'depth': base_depth * 4,
'depth_bottleneck': base_depth,
'stride': stride
}])
def resnet_v1_50(inputs, def resnet_v1_50(inputs,
num_classes=None, num_classes=None,
is_training=True, is_training=True,
...@@ -222,14 +247,10 @@ def resnet_v1_50(inputs, ...@@ -222,14 +247,10 @@ def resnet_v1_50(inputs,
scope='resnet_v1_50'): scope='resnet_v1_50'):
"""ResNet-50 model of [1]. See resnet_v1() for arg and return description.""" """ResNet-50 model of [1]. See resnet_v1() for arg and return description."""
blocks = [ blocks = [
resnet_utils.Block( resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), resnet_v1_block('block2', base_depth=128, num_units=4, stride=2),
resnet_utils.Block( resnet_v1_block('block3', base_depth=256, num_units=6, stride=2),
'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]), resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
resnet_utils.Block(
'block3', bottleneck, [(1024, 256, 1)] * 5 + [(1024, 256, 2)]),
resnet_utils.Block(
'block4', bottleneck, [(2048, 512, 1)] * 3)
] ]
return resnet_v1(inputs, blocks, num_classes, is_training, return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
...@@ -248,14 +269,10 @@ def resnet_v1_101(inputs, ...@@ -248,14 +269,10 @@ def resnet_v1_101(inputs,
scope='resnet_v1_101'): scope='resnet_v1_101'):
"""ResNet-101 model of [1]. See resnet_v1() for arg and return description.""" """ResNet-101 model of [1]. See resnet_v1() for arg and return description."""
blocks = [ blocks = [
resnet_utils.Block( resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), resnet_v1_block('block2', base_depth=128, num_units=4, stride=2),
resnet_utils.Block( resnet_v1_block('block3', base_depth=256, num_units=23, stride=2),
'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]), resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
resnet_utils.Block(
'block3', bottleneck, [(1024, 256, 1)] * 22 + [(1024, 256, 2)]),
resnet_utils.Block(
'block4', bottleneck, [(2048, 512, 1)] * 3)
] ]
return resnet_v1(inputs, blocks, num_classes, is_training, return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
...@@ -274,14 +291,11 @@ def resnet_v1_152(inputs, ...@@ -274,14 +291,11 @@ def resnet_v1_152(inputs,
scope='resnet_v1_152'): scope='resnet_v1_152'):
"""ResNet-152 model of [1]. See resnet_v1() for arg and return description.""" """ResNet-152 model of [1]. See resnet_v1() for arg and return description."""
blocks = [ blocks = [
resnet_utils.Block( resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), resnet_v1_block('block2', base_depth=128, num_units=8, stride=2),
resnet_utils.Block( resnet_v1_block('block3', base_depth=256, num_units=36, stride=2),
'block2', bottleneck, [(512, 128, 1)] * 7 + [(512, 128, 2)]), resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
resnet_utils.Block( ]
'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
resnet_utils.Block(
'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v1(inputs, blocks, num_classes, is_training, return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
include_root_block=True, spatial_squeeze=spatial_squeeze, include_root_block=True, spatial_squeeze=spatial_squeeze,
...@@ -299,14 +313,11 @@ def resnet_v1_200(inputs, ...@@ -299,14 +313,11 @@ def resnet_v1_200(inputs,
scope='resnet_v1_200'): scope='resnet_v1_200'):
"""ResNet-200 model of [2]. See resnet_v1() for arg and return description.""" """ResNet-200 model of [2]. See resnet_v1() for arg and return description."""
blocks = [ blocks = [
resnet_utils.Block( resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), resnet_v1_block('block2', base_depth=128, num_units=24, stride=2),
resnet_utils.Block( resnet_v1_block('block3', base_depth=256, num_units=36, stride=2),
'block2', bottleneck, [(512, 128, 1)] * 23 + [(512, 128, 2)]), resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
resnet_utils.Block( ]
'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
resnet_utils.Block(
'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v1(inputs, blocks, num_classes, is_training, return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
include_root_block=True, spatial_squeeze=spatial_squeeze, include_root_block=True, spatial_squeeze=spatial_squeeze,
......
...@@ -156,14 +156,17 @@ class ResnetUtilsTest(tf.test.TestCase): ...@@ -156,14 +156,17 @@ class ResnetUtilsTest(tf.test.TestCase):
with tf.variable_scope(scope, values=[inputs]): with tf.variable_scope(scope, values=[inputs]):
with slim.arg_scope([slim.conv2d], outputs_collections='end_points'): with slim.arg_scope([slim.conv2d], outputs_collections='end_points'):
net = resnet_utils.stack_blocks_dense(inputs, blocks, output_stride) net = resnet_utils.stack_blocks_dense(inputs, blocks, output_stride)
end_points = dict(tf.get_collection('end_points')) end_points = slim.utils.convert_collection_to_dict('end_points')
return net, end_points return net, end_points
def testEndPointsV1(self): def testEndPointsV1(self):
"""Test the end points of a tiny v1 bottleneck network.""" """Test the end points of a tiny v1 bottleneck network."""
bottleneck = resnet_v1.bottleneck blocks = [
blocks = [resnet_utils.Block('block1', bottleneck, [(4, 1, 1), (4, 1, 2)]), resnet_v1.resnet_v1_block(
resnet_utils.Block('block2', bottleneck, [(8, 2, 1), (8, 2, 1)])] 'block1', base_depth=1, num_units=2, stride=2),
resnet_v1.resnet_v1_block(
'block2', base_depth=2, num_units=2, stride=1),
]
inputs = create_test_input(2, 32, 16, 3) inputs = create_test_input(2, 32, 16, 3)
with slim.arg_scope(resnet_utils.resnet_arg_scope()): with slim.arg_scope(resnet_utils.resnet_arg_scope()):
_, end_points = self._resnet_plain(inputs, blocks, scope='tiny') _, end_points = self._resnet_plain(inputs, blocks, scope='tiny')
...@@ -189,30 +192,23 @@ class ResnetUtilsTest(tf.test.TestCase): ...@@ -189,30 +192,23 @@ class ResnetUtilsTest(tf.test.TestCase):
for block in blocks: for block in blocks:
with tf.variable_scope(block.scope, 'block', [net]): with tf.variable_scope(block.scope, 'block', [net]):
for i, unit in enumerate(block.args): for i, unit in enumerate(block.args):
depth, depth_bottleneck, stride = unit
with tf.variable_scope('unit_%d' % (i + 1), values=[net]): with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
net = block.unit_fn(net, net = block.unit_fn(net, rate=1, **unit)
depth=depth,
depth_bottleneck=depth_bottleneck,
stride=stride,
rate=1)
return net return net
def _atrousValues(self, bottleneck): def testAtrousValuesBottleneck(self):
"""Verify the values of dense feature extraction by atrous convolution. """Verify the values of dense feature extraction by atrous convolution.
Make sure that dense feature extraction by stack_blocks_dense() followed by Make sure that dense feature extraction by stack_blocks_dense() followed by
subsampling gives identical results to feature extraction at the nominal subsampling gives identical results to feature extraction at the nominal
network output stride using the simple self._stack_blocks_nondense() above. network output stride using the simple self._stack_blocks_nondense() above.
Args:
bottleneck: The bottleneck function.
""" """
block = resnet_v1.resnet_v1_block
blocks = [ blocks = [
resnet_utils.Block('block1', bottleneck, [(4, 1, 1), (4, 1, 2)]), block('block1', base_depth=1, num_units=2, stride=2),
resnet_utils.Block('block2', bottleneck, [(8, 2, 1), (8, 2, 2)]), block('block2', base_depth=2, num_units=2, stride=2),
resnet_utils.Block('block3', bottleneck, [(16, 4, 1), (16, 4, 2)]), block('block3', base_depth=4, num_units=2, stride=2),
resnet_utils.Block('block4', bottleneck, [(32, 8, 1), (32, 8, 1)]) block('block4', base_depth=8, num_units=2, stride=1),
] ]
nominal_stride = 8 nominal_stride = 8
...@@ -244,9 +240,6 @@ class ResnetUtilsTest(tf.test.TestCase): ...@@ -244,9 +240,6 @@ class ResnetUtilsTest(tf.test.TestCase):
output, expected = sess.run([output, expected]) output, expected = sess.run([output, expected])
self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4) self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4)
def testAtrousValuesBottleneck(self):
self._atrousValues(resnet_v1.bottleneck)
class ResnetCompleteNetworkTest(tf.test.TestCase): class ResnetCompleteNetworkTest(tf.test.TestCase):
"""Tests with complete small ResNet v1 networks.""" """Tests with complete small ResNet v1 networks."""
...@@ -261,16 +254,13 @@ class ResnetCompleteNetworkTest(tf.test.TestCase): ...@@ -261,16 +254,13 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
reuse=None, reuse=None,
scope='resnet_v1_small'): scope='resnet_v1_small'):
"""A shallow and thin ResNet v1 for faster tests.""" """A shallow and thin ResNet v1 for faster tests."""
bottleneck = resnet_v1.bottleneck block = resnet_v1.resnet_v1_block
blocks = [ blocks = [
resnet_utils.Block( block('block1', base_depth=1, num_units=3, stride=2),
'block1', bottleneck, [(4, 1, 1)] * 2 + [(4, 1, 2)]), block('block2', base_depth=2, num_units=3, stride=2),
resnet_utils.Block( block('block3', base_depth=4, num_units=3, stride=2),
'block2', bottleneck, [(8, 2, 1)] * 2 + [(8, 2, 2)]), block('block4', base_depth=8, num_units=2, stride=1),
resnet_utils.Block( ]
'block3', bottleneck, [(16, 4, 1)] * 2 + [(16, 4, 2)]),
resnet_utils.Block(
'block4', bottleneck, [(32, 8, 1)] * 2)]
return resnet_v1.resnet_v1(inputs, blocks, num_classes, return resnet_v1.resnet_v1(inputs, blocks, num_classes,
is_training=is_training, is_training=is_training,
global_pool=global_pool, global_pool=global_pool,
......
...@@ -25,6 +25,8 @@ introduced by: ...@@ -25,6 +25,8 @@ introduced by:
The key difference of the full preactivation 'v2' variant compared to the The key difference of the full preactivation 'v2' variant compared to the
'v1' variant in [1] is the use of batch normalization before every weight layer. 'v1' variant in [1] is the use of batch normalization before every weight layer.
Another difference is that 'v2' ResNets do not include an activation function in
the main pathway. Also see [2; Fig. 4e].
Typical use: Typical use:
...@@ -115,7 +117,7 @@ def resnet_v2(inputs, ...@@ -115,7 +117,7 @@ def resnet_v2(inputs,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
include_root_block=True, include_root_block=True,
spatial_squeeze=True, spatial_squeeze=False,
reuse=None, reuse=None,
scope=None): scope=None):
"""Generator for v2 (preactivation) ResNet models. """Generator for v2 (preactivation) ResNet models.
...@@ -212,31 +214,54 @@ def resnet_v2(inputs, ...@@ -212,31 +214,54 @@ def resnet_v2(inputs,
else: else:
logits = net logits = net
# Convert end_points_collection into a dictionary of end_points. # Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict(end_points_collection) end_points = slim.utils.convert_collection_to_dict(
end_points_collection)
if num_classes is not None: if num_classes is not None:
end_points['predictions'] = slim.softmax(logits, scope='predictions') end_points['predictions'] = slim.softmax(logits, scope='predictions')
return logits, end_points return logits, end_points
resnet_v2.default_image_size = 224 resnet_v2.default_image_size = 224
def resnet_v2_block(scope, base_depth, num_units, stride):
"""Helper function for creating a resnet_v2 bottleneck block.
Args:
scope: The scope of the block.
base_depth: The depth of the bottleneck layer for each unit.
num_units: The number of units in the block.
stride: The stride of the block, implemented as a stride in the last unit.
All other units have stride=1.
Returns:
A resnet_v2 bottleneck block.
"""
return resnet_utils.Block(scope, bottleneck, [{
'depth': base_depth * 4,
'depth_bottleneck': base_depth,
'stride': 1
}] * (num_units - 1) + [{
'depth': base_depth * 4,
'depth_bottleneck': base_depth,
'stride': stride
}])
resnet_v2.default_image_size = 224
def resnet_v2_50(inputs, def resnet_v2_50(inputs,
num_classes=None, num_classes=None,
is_training=True, is_training=True,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
spatial_squeeze=True, spatial_squeeze=False,
reuse=None, reuse=None,
scope='resnet_v2_50'): scope='resnet_v2_50'):
"""ResNet-50 model of [1]. See resnet_v2() for arg and return description.""" """ResNet-50 model of [1]. See resnet_v2() for arg and return description."""
blocks = [ blocks = [
resnet_utils.Block( resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), resnet_v2_block('block2', base_depth=128, num_units=4, stride=2),
resnet_utils.Block( resnet_v2_block('block3', base_depth=256, num_units=6, stride=2),
'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]), resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
resnet_utils.Block( ]
'block3', bottleneck, [(1024, 256, 1)] * 5 + [(1024, 256, 2)]),
resnet_utils.Block(
'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v2(inputs, blocks, num_classes, is_training=is_training, return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
include_root_block=True, spatial_squeeze=spatial_squeeze, include_root_block=True, spatial_squeeze=spatial_squeeze,
...@@ -249,19 +274,16 @@ def resnet_v2_101(inputs, ...@@ -249,19 +274,16 @@ def resnet_v2_101(inputs,
is_training=True, is_training=True,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
spatial_squeeze=True, spatial_squeeze=False,
reuse=None, reuse=None,
scope='resnet_v2_101'): scope='resnet_v2_101'):
"""ResNet-101 model of [1]. See resnet_v2() for arg and return description.""" """ResNet-101 model of [1]. See resnet_v2() for arg and return description."""
blocks = [ blocks = [
resnet_utils.Block( resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), resnet_v2_block('block2', base_depth=128, num_units=4, stride=2),
resnet_utils.Block( resnet_v2_block('block3', base_depth=256, num_units=23, stride=2),
'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]), resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
resnet_utils.Block( ]
'block3', bottleneck, [(1024, 256, 1)] * 22 + [(1024, 256, 2)]),
resnet_utils.Block(
'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v2(inputs, blocks, num_classes, is_training=is_training, return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
include_root_block=True, spatial_squeeze=spatial_squeeze, include_root_block=True, spatial_squeeze=spatial_squeeze,
...@@ -274,19 +296,16 @@ def resnet_v2_152(inputs, ...@@ -274,19 +296,16 @@ def resnet_v2_152(inputs,
is_training=True, is_training=True,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
spatial_squeeze=True, spatial_squeeze=False,
reuse=None, reuse=None,
scope='resnet_v2_152'): scope='resnet_v2_152'):
"""ResNet-152 model of [1]. See resnet_v2() for arg and return description.""" """ResNet-152 model of [1]. See resnet_v2() for arg and return description."""
blocks = [ blocks = [
resnet_utils.Block( resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), resnet_v2_block('block2', base_depth=128, num_units=8, stride=2),
resnet_utils.Block( resnet_v2_block('block3', base_depth=256, num_units=36, stride=2),
'block2', bottleneck, [(512, 128, 1)] * 7 + [(512, 128, 2)]), resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
resnet_utils.Block( ]
'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
resnet_utils.Block(
'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v2(inputs, blocks, num_classes, is_training=is_training, return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
include_root_block=True, spatial_squeeze=spatial_squeeze, include_root_block=True, spatial_squeeze=spatial_squeeze,
...@@ -299,19 +318,16 @@ def resnet_v2_200(inputs, ...@@ -299,19 +318,16 @@ def resnet_v2_200(inputs,
is_training=True, is_training=True,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
spatial_squeeze=True, spatial_squeeze=False,
reuse=None, reuse=None,
scope='resnet_v2_200'): scope='resnet_v2_200'):
"""ResNet-200 model of [2]. See resnet_v2() for arg and return description.""" """ResNet-200 model of [2]. See resnet_v2() for arg and return description."""
blocks = [ blocks = [
resnet_utils.Block( resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), resnet_v2_block('block2', base_depth=128, num_units=24, stride=2),
resnet_utils.Block( resnet_v2_block('block3', base_depth=256, num_units=36, stride=2),
'block2', bottleneck, [(512, 128, 1)] * 23 + [(512, 128, 2)]), resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
resnet_utils.Block( ]
'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
resnet_utils.Block(
'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v2(inputs, blocks, num_classes, is_training=is_training, return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
include_root_block=True, spatial_squeeze=spatial_squeeze, include_root_block=True, spatial_squeeze=spatial_squeeze,
......
...@@ -156,14 +156,17 @@ class ResnetUtilsTest(tf.test.TestCase): ...@@ -156,14 +156,17 @@ class ResnetUtilsTest(tf.test.TestCase):
with tf.variable_scope(scope, values=[inputs]): with tf.variable_scope(scope, values=[inputs]):
with slim.arg_scope([slim.conv2d], outputs_collections='end_points'): with slim.arg_scope([slim.conv2d], outputs_collections='end_points'):
net = resnet_utils.stack_blocks_dense(inputs, blocks, output_stride) net = resnet_utils.stack_blocks_dense(inputs, blocks, output_stride)
end_points = dict(tf.get_collection('end_points')) end_points = slim.utils.convert_collection_to_dict('end_points')
return net, end_points return net, end_points
def testEndPointsV2(self): def testEndPointsV2(self):
"""Test the end points of a tiny v2 bottleneck network.""" """Test the end points of a tiny v2 bottleneck network."""
bottleneck = resnet_v2.bottleneck blocks = [
blocks = [resnet_utils.Block('block1', bottleneck, [(4, 1, 1), (4, 1, 2)]), resnet_v2.resnet_v2_block(
resnet_utils.Block('block2', bottleneck, [(8, 2, 1), (8, 2, 1)])] 'block1', base_depth=1, num_units=2, stride=2),
resnet_v2.resnet_v2_block(
'block2', base_depth=2, num_units=2, stride=1),
]
inputs = create_test_input(2, 32, 16, 3) inputs = create_test_input(2, 32, 16, 3)
with slim.arg_scope(resnet_utils.resnet_arg_scope()): with slim.arg_scope(resnet_utils.resnet_arg_scope()):
_, end_points = self._resnet_plain(inputs, blocks, scope='tiny') _, end_points = self._resnet_plain(inputs, blocks, scope='tiny')
...@@ -189,30 +192,23 @@ class ResnetUtilsTest(tf.test.TestCase): ...@@ -189,30 +192,23 @@ class ResnetUtilsTest(tf.test.TestCase):
for block in blocks: for block in blocks:
with tf.variable_scope(block.scope, 'block', [net]): with tf.variable_scope(block.scope, 'block', [net]):
for i, unit in enumerate(block.args): for i, unit in enumerate(block.args):
depth, depth_bottleneck, stride = unit
with tf.variable_scope('unit_%d' % (i + 1), values=[net]): with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
net = block.unit_fn(net, net = block.unit_fn(net, rate=1, **unit)
depth=depth,
depth_bottleneck=depth_bottleneck,
stride=stride,
rate=1)
return net return net
def _atrousValues(self, bottleneck): def testAtrousValuesBottleneck(self):
"""Verify the values of dense feature extraction by atrous convolution. """Verify the values of dense feature extraction by atrous convolution.
Make sure that dense feature extraction by stack_blocks_dense() followed by Make sure that dense feature extraction by stack_blocks_dense() followed by
subsampling gives identical results to feature extraction at the nominal subsampling gives identical results to feature extraction at the nominal
network output stride using the simple self._stack_blocks_nondense() above. network output stride using the simple self._stack_blocks_nondense() above.
Args:
bottleneck: The bottleneck function.
""" """
block = resnet_v2.resnet_v2_block
blocks = [ blocks = [
resnet_utils.Block('block1', bottleneck, [(4, 1, 1), (4, 1, 2)]), block('block1', base_depth=1, num_units=2, stride=2),
resnet_utils.Block('block2', bottleneck, [(8, 2, 1), (8, 2, 2)]), block('block2', base_depth=2, num_units=2, stride=2),
resnet_utils.Block('block3', bottleneck, [(16, 4, 1), (16, 4, 2)]), block('block3', base_depth=4, num_units=2, stride=2),
resnet_utils.Block('block4', bottleneck, [(32, 8, 1), (32, 8, 1)]) block('block4', base_depth=8, num_units=2, stride=1),
] ]
nominal_stride = 8 nominal_stride = 8
...@@ -244,9 +240,6 @@ class ResnetUtilsTest(tf.test.TestCase): ...@@ -244,9 +240,6 @@ class ResnetUtilsTest(tf.test.TestCase):
output, expected = sess.run([output, expected]) output, expected = sess.run([output, expected])
self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4) self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4)
def testAtrousValuesBottleneck(self):
self._atrousValues(resnet_v2.bottleneck)
class ResnetCompleteNetworkTest(tf.test.TestCase): class ResnetCompleteNetworkTest(tf.test.TestCase):
"""Tests with complete small ResNet v2 networks.""" """Tests with complete small ResNet v2 networks."""
...@@ -261,16 +254,13 @@ class ResnetCompleteNetworkTest(tf.test.TestCase): ...@@ -261,16 +254,13 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
reuse=None, reuse=None,
scope='resnet_v2_small'): scope='resnet_v2_small'):
"""A shallow and thin ResNet v2 for faster tests.""" """A shallow and thin ResNet v2 for faster tests."""
bottleneck = resnet_v2.bottleneck block = resnet_v2.resnet_v2_block
blocks = [ blocks = [
resnet_utils.Block( block('block1', base_depth=1, num_units=3, stride=2),
'block1', bottleneck, [(4, 1, 1)] * 2 + [(4, 1, 2)]), block('block2', base_depth=2, num_units=3, stride=2),
resnet_utils.Block( block('block3', base_depth=4, num_units=3, stride=2),
'block2', bottleneck, [(8, 2, 1)] * 2 + [(8, 2, 2)]), block('block4', base_depth=8, num_units=2, stride=1),
resnet_utils.Block( ]
'block3', bottleneck, [(16, 4, 1)] * 2 + [(16, 4, 2)]),
resnet_utils.Block(
'block4', bottleneck, [(32, 8, 1)] * 2)]
return resnet_v2.resnet_v2(inputs, blocks, num_classes, return resnet_v2.resnet_v2(inputs, blocks, num_classes,
is_training=is_training, is_training=is_training,
global_pool=global_pool, global_pool=global_pool,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment