Unverified Commit f6bba08d, authored by Taylor Robie and committed by GitHub

Add resnet v1 and numerical unit tests to official resnet. (#3484)

parent 6e3e5c38
......@@ -4,9 +4,11 @@ Deep residual networks, or ResNets for short, provided the breakthrough idea of
See the following papers for more background:
[Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.
[1] [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.
[Identity Mappings in Deep Residual Networks](https://arxiv.org/pdf/1603.05027.pdf) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.
[2] [Identity Mappings in Deep Residual Networks](https://arxiv.org/pdf/1603.05027.pdf) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.
In the code, v1 refers to the ResNet defined in [1], while v2 correspondingly refers to [2]. The principal difference between the two versions is that v1 applies batch normalization and activation after convolution, while v2 applies batch normalization, then activation, and finally convolution. A schematic comparison is presented in Figure 1 (left) of [2].
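For illustration only (this sketch is not part of the model code), the two orderings can be written as follows in TensorFlow; `conv_bn_relu_v1` and `bn_relu_conv_v2` are hypothetical helper names:

```python
import tensorflow as tf

def conv_bn_relu_v1(inputs, filters, training):
  """v1 ordering: convolution, then batch normalization, then ReLU."""
  x = tf.layers.conv2d(inputs, filters, kernel_size=3, padding='same',
                       use_bias=False)
  x = tf.layers.batch_normalization(x, training=training)
  return tf.nn.relu(x)

def bn_relu_conv_v2(inputs, filters, training):
  """v2 ('pre-activation') ordering: batch norm, then ReLU, then convolution."""
  x = tf.layers.batch_normalization(inputs, training=training)
  x = tf.nn.relu(x)
  return tf.layers.conv2d(x, filters, kernel_size=3, padding='same',
                          use_bias=False)
```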
Please proceed according to which dataset you would like to train/evaluate on:
......
......@@ -140,7 +140,8 @@ def get_synth_input_fn():
###############################################################################
class Cifar10Model(resnet.Model):
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES):
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
version=resnet.DEFAULT_VERSION):
"""These are the parameters that work for CIFAR-10 data.
Args:
......@@ -149,6 +150,8 @@ class Cifar10Model(resnet.Model):
data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets.
version: Integer representing which version of the ResNet network to use.
See README for details. Valid values: [1, 2]
"""
if resnet_size % 6 != 2:
raise ValueError('resnet_size must be 6n + 2:', resnet_size)
......@@ -157,6 +160,7 @@ class Cifar10Model(resnet.Model):
super(Cifar10Model, self).__init__(
resnet_size=resnet_size,
bottleneck=False,
num_classes=num_classes,
num_filters=16,
kernel_size=3,
......@@ -165,10 +169,10 @@ class Cifar10Model(resnet.Model):
first_pool_stride=None,
second_pool_size=8,
second_pool_stride=1,
block_fn=resnet.building_block,
block_sizes=[num_blocks] * 3,
block_strides=[1, 2, 2],
final_size=64,
version=version,
data_format=data_format)
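As a hypothetical usage sketch (not part of this diff), the updated constructor can be exercised as follows; the 32x32x3 input shape and 10-class output are CIFAR-10 properties, everything else is illustrative:

```python
import tensorflow as tf
import cifar10_main

# Build a ResNet-32 CIFAR-10 model using the v1 block ordering.
# resnet_size must satisfy 6n + 2 (e.g. 20, 32, 44, 56).
model = cifar10_main.Cifar10Model(32, data_format='channels_last', version=1)
images = tf.random_uniform([8, 32, 32, 3])  # fake batch of CIFAR-10 images
logits = model(images, training=False)      # expected shape: [8, 10]
```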
......@@ -199,6 +203,7 @@ def cifar10_model_fn(features, labels, mode, params):
learning_rate_fn=learning_rate_fn,
momentum=0.9,
data_format=params['data_format'],
version=params['version'],
loss_filter_fn=loss_filter_fn,
multi_gpu=params['multi_gpu'])
......
......@@ -64,7 +64,7 @@ class BaseTest(tf.test.TestCase):
for pixel in row:
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
def cifar10_model_fn_helper(self, mode, multi_gpu=False):
def cifar10_model_fn_helper(self, mode, version, multi_gpu=False):
input_fn = cifar10_main.get_synth_input_fn()
dataset = input_fn(True, '', _BATCH_SIZE)
iterator = dataset.make_one_shot_iterator()
......@@ -74,6 +74,7 @@ class BaseTest(tf.test.TestCase):
'resnet_size': 32,
'data_format': 'channels_last',
'batch_size': _BATCH_SIZE,
'version': version,
'multi_gpu': multi_gpu
})
......@@ -96,28 +97,43 @@ class BaseTest(tf.test.TestCase):
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
def test_cifar10_model_fn_train_mode(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN)
def test_cifar10_model_fn_train_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
def test_cifar10_model_fn_train_mode_multi_gpu(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, multi_gpu=True)
def test_cifar10_model_fn_train_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2)
def test_cifar10_model_fn_eval_mode(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL)
def test_cifar10_model_fn_train_mode_multi_gpu_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
multi_gpu=True)
def test_cifar10_model_fn_predict_mode(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT)
def test_cifar10_model_fn_train_mode_multi_gpu_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
multi_gpu=True)
def test_cifar10_model_fn_eval_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1)
def test_cifar10_model_fn_eval_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2)
def test_cifar10_model_fn_predict_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1)
def test_cifar10_model_fn_predict_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2)
def test_cifar10model_shape(self):
batch_size = 135
num_classes = 246
model = cifar10_main.Cifar10Model(
32, data_format='channels_last', num_classes=num_classes)
fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
output = model(fake_input, training=True)
for version in (1, 2):
model = cifar10_main.Cifar10Model(32, data_format='channels_last',
num_classes=num_classes, version=version)
fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
output = model(fake_input, training=True)
self.assertAllEqual(output.shape, (batch_size, num_classes))
self.assertAllEqual(output.shape, (batch_size, num_classes))
if __name__ == '__main__':
......
......@@ -164,7 +164,8 @@ def get_synth_input_fn():
###############################################################################
class ImagenetModel(resnet.Model):
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES):
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
version=resnet.DEFAULT_VERSION):
"""These are the parameters that work for Imagenet data.
Args:
......@@ -173,18 +174,21 @@ class ImagenetModel(resnet.Model):
data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets.
version: Integer representing which version of the ResNet network to use.
See README for details. Valid values: [1, 2]
"""
# For bigger models, we want to use "bottleneck" layers
if resnet_size < 50:
block_fn = resnet.building_block
bottleneck = False
final_size = 512
else:
block_fn = resnet.bottleneck_block
bottleneck = True
final_size = 2048
super(ImagenetModel, self).__init__(
resnet_size=resnet_size,
bottleneck=bottleneck,
num_classes=num_classes,
num_filters=64,
kernel_size=7,
......@@ -193,10 +197,10 @@ class ImagenetModel(resnet.Model):
first_pool_stride=2,
second_pool_size=7,
second_pool_stride=1,
block_fn=block_fn,
block_sizes=_get_block_sizes(resnet_size),
block_strides=[1, 2, 2, 2],
final_size=final_size,
version=version,
data_format=data_format)
......@@ -236,6 +240,7 @@ def imagenet_model_fn(features, labels, mode, params):
learning_rate_fn=learning_rate_fn,
momentum=0.9,
data_format=params['data_format'],
version=params['version'],
loss_filter_fn=None,
multi_gpu=params['multi_gpu'])
......
......@@ -31,7 +31,7 @@ _LABEL_CLASSES = 1001
class BaseTest(tf.test.TestCase):
def tensor_shapes_helper(self, resnet_size, with_gpu=False):
def tensor_shapes_helper(self, resnet_size, version, with_gpu=False):
"""Checks the tensor shapes after each phase of the ResNet model."""
def reshape(shape):
"""Returns the expected dimensions depending on if a
......@@ -49,7 +49,8 @@ class BaseTest(tf.test.TestCase):
use_gpu=with_gpu, force_gpu=with_gpu):
model = imagenet_main.ImagenetModel(
resnet_size,
data_format='channels_first' if with_gpu else 'channels_last')
data_format='channels_first' if with_gpu else 'channels_last',
version=version)
inputs = tf.random_uniform([1, 224, 224, 3])
output = model(inputs, training=True)
......@@ -83,49 +84,91 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(dense.shape, (1, _LABEL_CLASSES))
self.assertAllEqual(output.shape, (1, _LABEL_CLASSES))
def test_tensor_shapes_resnet_18(self):
self.tensor_shapes_helper(18)
def test_tensor_shapes_resnet_18_v1(self):
self.tensor_shapes_helper(18, version=1)
def test_tensor_shapes_resnet_34(self):
self.tensor_shapes_helper(34)
def test_tensor_shapes_resnet_18_v2(self):
self.tensor_shapes_helper(18, version=2)
def test_tensor_shapes_resnet_50(self):
self.tensor_shapes_helper(50)
def test_tensor_shapes_resnet_34_v1(self):
self.tensor_shapes_helper(34, version=1)
def test_tensor_shapes_resnet_101(self):
self.tensor_shapes_helper(101)
def test_tensor_shapes_resnet_34_v2(self):
self.tensor_shapes_helper(34, version=2)
def test_tensor_shapes_resnet_152(self):
self.tensor_shapes_helper(152)
def test_tensor_shapes_resnet_50_v1(self):
self.tensor_shapes_helper(50, version=1)
def test_tensor_shapes_resnet_200(self):
self.tensor_shapes_helper(200)
def test_tensor_shapes_resnet_50_v2(self):
self.tensor_shapes_helper(50, version=2)
def test_tensor_shapes_resnet_101_v1(self):
self.tensor_shapes_helper(101, version=1)
def test_tensor_shapes_resnet_101_v2(self):
self.tensor_shapes_helper(101, version=2)
def test_tensor_shapes_resnet_152_v1(self):
self.tensor_shapes_helper(152, version=1)
def test_tensor_shapes_resnet_152_v2(self):
self.tensor_shapes_helper(152, version=2)
def test_tensor_shapes_resnet_200_v1(self):
self.tensor_shapes_helper(200, version=1)
def test_tensor_shapes_resnet_200_v2(self):
self.tensor_shapes_helper(200, version=2)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_18_with_gpu_v1(self):
self.tensor_shapes_helper(18, version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_18_with_gpu_v2(self):
self.tensor_shapes_helper(18, version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_34_with_gpu_v1(self):
self.tensor_shapes_helper(34, version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_34_with_gpu_v2(self):
self.tensor_shapes_helper(34, version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_18_with_gpu(self):
self.tensor_shapes_helper(18, True)
def test_tensor_shapes_resnet_50_with_gpu_v1(self):
self.tensor_shapes_helper(50, version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_34_with_gpu(self):
self.tensor_shapes_helper(34, True)
def test_tensor_shapes_resnet_50_with_gpu_v2(self):
self.tensor_shapes_helper(50, version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_50_with_gpu(self):
self.tensor_shapes_helper(50, True)
def test_tensor_shapes_resnet_101_with_gpu_v1(self):
self.tensor_shapes_helper(101, version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_101_with_gpu(self):
self.tensor_shapes_helper(101, True)
def test_tensor_shapes_resnet_101_with_gpu_v2(self):
self.tensor_shapes_helper(101, version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_152_with_gpu(self):
self.tensor_shapes_helper(152, True)
def test_tensor_shapes_resnet_152_with_gpu_v1(self):
self.tensor_shapes_helper(152, version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_200_with_gpu(self):
self.tensor_shapes_helper(200, True)
def test_tensor_shapes_resnet_152_with_gpu_v2(self):
self.tensor_shapes_helper(152, version=2, with_gpu=True)
def resnet_model_fn_helper(self, mode, multi_gpu=False):
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_200_with_gpu_v1(self):
self.tensor_shapes_helper(200, version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_200_with_gpu_v2(self):
self.tensor_shapes_helper(200, version=2, with_gpu=True)
def resnet_model_fn_helper(self, mode, version, multi_gpu=False):
"""Tests that the EstimatorSpec is given the appropriate arguments."""
tf.train.create_global_step()
......@@ -138,6 +181,7 @@ class BaseTest(tf.test.TestCase):
'resnet_size': 50,
'data_format': 'channels_last',
'batch_size': _BATCH_SIZE,
'version': version,
'multi_gpu': multi_gpu,
})
......@@ -160,28 +204,43 @@ class BaseTest(tf.test.TestCase):
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
def test_resnet_model_fn_train_mode(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN)
def test_resnet_model_fn_train_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
def test_resnet_model_fn_train_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2)
def test_resnet_model_fn_train_mode_multi_gpu_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
multi_gpu=True)
def test_resnet_model_fn_train_mode_multi_gpu_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
multi_gpu=True)
def test_resnet_model_fn_eval_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1)
def test_resnet_model_fn_train_mode_multi_gpu(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, multi_gpu=True)
def test_resnet_model_fn_eval_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2)
def test_resnet_model_fn_eval_mode(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL)
def test_resnet_model_fn_predict_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1)
def test_resnet_model_fn_predict_mode(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT)
def test_resnet_model_fn_predict_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2)
def test_imagenetmodel_shape(self):
batch_size = 135
num_classes = 246
model = imagenet_main.ImagenetModel(
50, data_format='channels_last', num_classes=num_classes)
fake_input = tf.random_uniform([batch_size, 224, 224, 3])
output = model(fake_input, training=True)
for version in (1, 2):
model = imagenet_main.ImagenetModel(50, data_format='channels_last',
num_classes=num_classes, version=version)
fake_input = tf.random_uniform([batch_size, 224, 224, 3])
output = model(fake_input, training=True)
self.assertAllEqual(output.shape, (batch_size, num_classes))
self.assertAllEqual(output.shape, (batch_size, num_classes))
if __name__ == '__main__':
......
......@@ -12,15 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains definitions for the preactivation form of Residual Networks
(also known as ResNet v2).
"""Contains definitions for Residual Networks.
Residual networks (ResNets) were originally proposed in:
Residual networks ('v1' ResNets) were originally proposed in:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. arXiv:1512.03385
The full preactivation 'v2' ResNet variant implemented in this module was
introduced by:
The full preactivation 'v2' ResNet variant was introduced by:
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Identity Mappings in Deep Residual Networks. arXiv:1603.05027
......@@ -41,6 +39,8 @@ import tensorflow as tf
_BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5
DEFAULT_VERSION = 2
################################################################################
# Functions for input processing.
......@@ -139,18 +139,16 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
################################################################################
# Functions building the ResNet model.
# Convenience functions for building the ResNet model.
################################################################################
def batch_norm_relu(inputs, training, data_format):
"""Performs a batch normalization followed by a ReLU."""
def batch_norm(inputs, training, data_format):
"""Performs a batch normalization using a standard set of parameters."""
# We set fused=True for a significant performance boost. See
# https://www.tensorflow.org/performance/performance_guide#common_fused_ops
inputs = tf.layers.batch_normalization(
return tf.layers.batch_normalization(
inputs=inputs, axis=1 if data_format == 'channels_first' else 3,
momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON, center=True,
scale=True, training=training, fused=True)
inputs = tf.nn.relu(inputs)
return inputs
def fixed_padding(inputs, kernel_size, data_format):
......@@ -194,9 +192,16 @@ def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format):
data_format=data_format)
def building_block(inputs, filters, training, projection_shortcut, strides,
data_format):
"""Standard building block for residual networks with BN before convolutions.
################################################################################
# ResNet block definitions.
################################################################################
def _building_block_v1(inputs, filters, training, projection_shortcut, strides,
data_format):
"""
Convolution then batch normalization then ReLU as described by:
Deep Residual Learning for Image Recognition
https://arxiv.org/pdf/1512.03385.pdf
by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.
Args:
inputs: A tensor of size [batch, channels, height_in, width_in] or
......@@ -214,34 +219,40 @@ def building_block(inputs, filters, training, projection_shortcut, strides,
The output tensor of the block.
"""
shortcut = inputs
inputs = batch_norm_relu(inputs, training, data_format)
# The projection shortcut should come after the first batch norm and ReLU
# since it performs a 1x1 convolution.
if projection_shortcut is not None:
shortcut = projection_shortcut(inputs)
shortcut = batch_norm(inputs=shortcut, training=training,
data_format=data_format)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=strides,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
inputs = batch_norm_relu(inputs, training, data_format)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=1,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs += shortcut
inputs = tf.nn.relu(inputs)
return inputs + shortcut
return inputs
def bottleneck_block(inputs, filters, training, projection_shortcut,
strides, data_format):
"""Bottleneck block variant for residual networks with BN before convolutions.
def _building_block_v2(inputs, filters, training, projection_shortcut, strides,
data_format):
"""
Batch normalization then ReLu then convolution as described by:
Identity Mappings in Deep Residual Networks
https://arxiv.org/pdf/1603.05027.pdf
by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.
Args:
inputs: A tensor of size [batch, channels, height_in, width_in] or
[batch, height_in, width_in, channels] depending on data_format.
filters: The number of filters for the first two convolutions. Note
that the third and final convolution will use 4 times as many filters.
filters: The number of filters for the convolutions.
training: A Boolean for whether the model is in training or inference
mode. Needed for batch normalization.
projection_shortcut: The function to use for projection shortcuts
......@@ -254,23 +265,103 @@ def bottleneck_block(inputs, filters, training, projection_shortcut,
The output tensor of the block.
"""
shortcut = inputs
inputs = batch_norm_relu(inputs, training, data_format)
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
# The projection shortcut should come after the first batch norm and ReLU
# since it performs a 1x1 convolution.
if projection_shortcut is not None:
shortcut = projection_shortcut(inputs)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=strides,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=1,
data_format=data_format)
return inputs + shortcut
def _bottleneck_block_v1(inputs, filters, training, projection_shortcut,
strides, data_format):
"""
Similar to _building_block_v1(), except using the "bottleneck" blocks
described in:
Convolution then batch normalization then ReLU as described by:
Deep Residual Learning for Image Recognition
https://arxiv.org/pdf/1512.03385.pdf
by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.
"""
shortcut = inputs
if projection_shortcut is not None:
shortcut = projection_shortcut(inputs)
shortcut = batch_norm(inputs=shortcut, training=training,
data_format=data_format)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=1, strides=1,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
inputs = batch_norm_relu(inputs, training, data_format)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=strides,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=4 * filters, kernel_size=1, strides=1,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs += shortcut
inputs = tf.nn.relu(inputs)
return inputs
def _bottleneck_block_v2(inputs, filters, training, projection_shortcut,
strides, data_format):
"""
Similar to _building_block_v2(), except using the "bottleneck" blocks
described in:
Convolution then batch normalization then ReLU as described by:
Deep Residual Learning for Image Recognition
https://arxiv.org/pdf/1512.03385.pdf
by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.
adapted to the ordering conventions of:
Batch normalization then ReLU then convolution as described by:
Identity Mappings in Deep Residual Networks
https://arxiv.org/pdf/1603.05027.pdf
by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.
"""
shortcut = inputs
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
# The projection shortcut should come after the first batch norm and ReLU
# since it performs a 1x1 convolution.
if projection_shortcut is not None:
shortcut = projection_shortcut(inputs)
inputs = batch_norm_relu(inputs, training, data_format)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=1, strides=1,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=strides,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=4 * filters, kernel_size=1, strides=1,
data_format=data_format)
......@@ -278,14 +369,15 @@ def bottleneck_block(inputs, filters, training, projection_shortcut,
return inputs + shortcut
def block_layer(inputs, filters, block_fn, blocks, strides, training, name,
data_format):
def block_layer(inputs, filters, bottleneck, block_fn, blocks, strides,
training, name, data_format):
"""Creates one layer of blocks for the ResNet model.
Args:
inputs: A tensor of size [batch, channels, height_in, width_in] or
[batch, height_in, width_in, channels] depending on data_format.
filters: The number of filters for the first convolution of the layer.
bottleneck: Whether the blocks created are bottleneck blocks.
block_fn: The block function to use within the model; one of the
_building_block_v* or _bottleneck_block_v* functions defined above.
blocks: The number of blocks contained in the layer.
......@@ -299,8 +391,9 @@ def block_layer(inputs, filters, block_fn, blocks, strides, training, name,
Returns:
The output tensor of the block layer.
"""
# Bottleneck blocks end with 4x the number of filters as they start with
filters_out = 4 * filters if block_fn is bottleneck_block else filters
filters_out = filters * 4 if bottleneck else filters
def projection_shortcut(inputs):
return conv2d_fixed_padding(
......@@ -318,17 +411,19 @@ def block_layer(inputs, filters, block_fn, blocks, strides, training, name,
class Model(object):
"""Base class for building the Resnet v2 Model.
"""Base class for building the Resnet Model.
"""
def __init__(self, resnet_size, num_classes, num_filters, kernel_size,
def __init__(self, resnet_size, bottleneck, num_classes, num_filters,
kernel_size,
conv_stride, first_pool_size, first_pool_stride,
second_pool_size, second_pool_stride, block_fn, block_sizes,
block_strides, final_size, data_format=None):
second_pool_size, second_pool_stride, block_sizes, block_strides,
final_size, version=DEFAULT_VERSION, data_format=None):
"""Creates a model for classifying an image.
Args:
resnet_size: A single integer for the size of the ResNet model.
bottleneck: Whether to use bottleneck blocks rather than regular building blocks.
num_classes: The number of classes used as labels.
num_filters: The number of filters to use for the first block layer
of the model. This number is then doubled for each subsequent block
......@@ -341,14 +436,14 @@ class Model(object):
if first_pool_size is None.
second_pool_size: Pool size to be used for the second pooling layer.
second_pool_stride: stride size for the final pooling layer
block_fn: Which block layer function should be used? Pass in one of
the two functions defined above: building_block or bottleneck_block
block_sizes: A list containing n values, where n is the number of sets of
block layers desired. Each value should be the number of blocks in the
i-th set.
block_strides: List of integers representing the desired stride size for
each of the sets of block layers. Should be same length as block_sizes.
final_size: The expected size of the model after the second pooling.
version: Integer representing which version of the ResNet network to use.
See README for details. Valid values: [1, 2]
data_format: Input format ('channels_last', 'channels_first', or None).
If set to None, the format is dependent on whether a GPU is available.
"""
......@@ -358,6 +453,23 @@ class Model(object):
data_format = (
'channels_first' if tf.test.is_built_with_cuda() else 'channels_last')
self.resnet_version = version
if version not in (1, 2):
raise ValueError(
"Resnet version should be 1 or 2. See README for citations.")
self.bottleneck = bottleneck
if bottleneck:
if version == 1:
self.block_fn = _bottleneck_block_v1
else:
self.block_fn = _bottleneck_block_v2
else:
if version == 1:
self.block_fn = _building_block_v1
else:
self.block_fn = _building_block_v2
self.data_format = data_format
self.num_classes = num_classes
self.num_filters = num_filters
......@@ -367,7 +479,6 @@ class Model(object):
self.first_pool_stride = first_pool_stride
self.second_pool_size = second_pool_size
self.second_pool_stride = second_pool_stride
self.block_fn = block_fn
self.block_sizes = block_sizes
self.block_strides = block_strides
self.final_size = final_size
......@@ -405,12 +516,13 @@ class Model(object):
for i, num_blocks in enumerate(self.block_sizes):
num_filters = self.num_filters * (2**i)
inputs = block_layer(
inputs=inputs, filters=num_filters, block_fn=self.block_fn,
blocks=num_blocks, strides=self.block_strides[i],
training=training, name='block_layer{}'.format(i + 1),
data_format=self.data_format)
inputs=inputs, filters=num_filters, bottleneck=self.bottleneck,
block_fn=self.block_fn, blocks=num_blocks,
strides=self.block_strides[i], training=training,
name='block_layer{}'.format(i + 1), data_format=self.data_format)
inputs = batch_norm_relu(inputs, training, self.data_format)
inputs = batch_norm(inputs, training, self.data_format)
inputs = tf.nn.relu(inputs)
inputs = tf.layers.average_pooling2d(
inputs=inputs, pool_size=self.second_pool_size,
strides=self.second_pool_stride, padding='VALID',
......@@ -463,7 +575,7 @@ def learning_rate_with_decay(
def resnet_model_fn(features, labels, mode, model_class,
resnet_size, weight_decay, learning_rate_fn, momentum,
data_format, loss_filter_fn=None, multi_gpu=False):
data_format, version, loss_filter_fn=None, multi_gpu=False):
"""Shared functionality for different resnet model_fns.
Initializes the ResnetModel representing the model layers
......@@ -487,6 +599,8 @@ def resnet_model_fn(features, labels, mode, model_class,
momentum: momentum term used for optimization
data_format: Input format ('channels_last', 'channels_first', or None).
If set to None, the format is dependent on whether a GPU is available.
version: Integer representing which version of the ResNet network to use.
See README for details. Valid values: [1, 2]
loss_filter_fn: function that takes a string variable name and returns
True if the var should be included in loss calculation, and False
otherwise. If None, batch_normalization variables will be excluded
......@@ -502,7 +616,7 @@ def resnet_model_fn(features, labels, mode, model_class,
# Generate a summary node for the images
tf.summary.image('images', features, max_outputs=6)
model = model_class(resnet_size, data_format)
model = model_class(resnet_size, data_format, version=version)
logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)
predictions = {
......@@ -628,6 +742,7 @@ def resnet_main(flags, model_function, input_function):
'data_format': flags.data_format,
'batch_size': flags.batch_size,
'multi_gpu': flags.multi_gpu,
'version': flags.version,
})
for _ in range(flags.train_epochs // flags.epochs_per_eval):
......@@ -710,6 +825,12 @@ class ResnetArgParser(argparse.ArgumentParser):
'--multi_gpu', action='store_true',
help='If set, run across all available GPUs.')
self.add_argument(
'-v', '--version', type=int, choices=[1, 2], dest="version",
default=DEFAULT_VERSION,
help="Version of ResNet. (1 or 2) See README.md for details."
)
# Advanced args
self.add_argument(
'--use_synthetic_data', action='store_true',
......
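A minimal sketch of how the new flag might be consumed (assuming the parser requires no additional arguments; only flags visible in this diff are used, and the snippet is illustrative rather than part of the change):

```python
# Hypothetical usage of the new -v/--version flag.
parser = ResnetArgParser()
flags = parser.parse_args(['-v', '1', '--multi_gpu'])
assert flags.version in (1, 2)

# The parsed value is forwarded to the Estimator via params, as in resnet_main().
params = {'version': flags.version, 'multi_gpu': flags.multi_gpu}
```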
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import resnet
import tensorflow as tf
class BlockTest(tf.test.TestCase):
def dense_run(self, tf_seed):
"""Simple generation of one random float and a single node dense network.
The subsequent more involved tests depend on the ability to correctly seed
TensorFlow. In the event that that process does not function as expected,
the simple dense tests will fail indicating that the issue is with the
tests rather than the ResNet functions.
Args:
tf_seed: Random seed for TensorFlow
Returns:
The generated random number and result of the dense network.
"""
with self.test_session(graph=tf.Graph()) as sess:
tf.set_random_seed(tf_seed)
x = tf.random_uniform((1, 1))
y = tf.layers.dense(inputs=x, units=1)
init = tf.global_variables_initializer()
sess.run(init)
return x.eval()[0, 0], y.eval()[0, 0]
def make_projection(self, filters_out, strides, data_format):
"""1D convolution with stride projector.
Args:
filters_out: Number of filters in the projection.
strides: Stride length for convolution.
data_format: channels_first or channels_last
Returns:
A function that applies a 1x1 convolution projection shortcut.
"""
def projection_shortcut(inputs):
return resnet.conv2d_fixed_padding(
inputs=inputs, filters=filters_out, kernel_size=1, strides=strides,
data_format=data_format)
return projection_shortcut
def resnet_block_run(self, tf_seed, batch_size, bottleneck, projection,
version, width, channels):
"""Test whether resnet block construction has changed.
This function runs ResNet block construction under a variety of different
conditions.
Args:
tf_seed: Random seed for TensorFlow
batch_size: Number of images in the fake input batch. This is needed
because of batch normalization.
bottleneck: Whether or not to use bottleneck layers.
projection: Whether or not to project the input.
version: Which version of ResNet to test.
width: The width of the fake image.
channels: The number of channels in the fake image.
Returns:
The size of the block output, as well as several check values.
"""
data_format = "channels_last"
if version == 1:
block_fn = resnet._building_block_v1
if bottleneck:
block_fn = resnet._bottleneck_block_v1
else:
block_fn = resnet._building_block_v2
if bottleneck:
block_fn = resnet._bottleneck_block_v2
with self.test_session(graph=tf.Graph()) as sess:
tf.set_random_seed(tf_seed)
strides = 1
channels_out = channels
projection_shortcut = None
if projection:
strides = 2
channels_out *= strides
projection_shortcut = self.make_projection(
filters_out=channels_out, strides=strides, data_format=data_format)
filters = channels_out
if bottleneck:
filters = channels_out // 4
x = tf.random_uniform((batch_size, width, width, channels))
y = block_fn(inputs=x, filters=filters, training=True,
projection_shortcut=projection_shortcut, strides=strides,
data_format=data_format)
init = tf.global_variables_initializer()
sess.run(init)
y_array = y.eval()
y_flat = y_array.flatten()
return y_array.shape, (y_flat[0], y_flat[-1], np.sum(y_flat))
def test_dense_0(self):
"""Sanity check 0 on dense layer."""
computed = self.dense_run(1813835975)
self.assertAllClose(computed, (0.8760674, 0.2547844))
def test_dense_1(self):
"""Sanity check 1 on dense layer."""
computed = self.dense_run(3574260356)
self.assertAllClose(computed, (0.75590825, 0.5339718))
def test_bottleneck_v1_width_32_channels_64_batch_size_32_with_proj(self):
"""Test of a single ResNet block."""
computed_size, computed_values = self.resnet_block_run(
599400476, batch_size=32, bottleneck=True, projection=True,
version=1, width=32, channels=64)
self.assertAllEqual(computed_size, (32, 16, 16, 128))
self.assertAllClose(computed_values, (0.0, 0.92648625, 587702.4))
def test_bottleneck_v2_width_32_channels_64_batch_size_32_with_proj(self):
"""Test of a single ResNet block."""
computed_size, computed_values = self.resnet_block_run(
309580726, batch_size=32, bottleneck=True, projection=True,
version=2, width=32, channels=64)
self.assertAllEqual(computed_size, (32, 16, 16, 128))
self.assertAllClose(computed_values, (-1.8759897, -0.5546854, -12860.312))
def test_bottleneck_v1_width_32_channels_64_batch_size_32(self):
"""Test of a single ResNet block."""
computed_size, computed_values = self.resnet_block_run(
1969060699, batch_size=32, bottleneck=True, projection=False,
version=1, width=32, channels=64)
self.assertAllEqual(computed_size, (32, 32, 32, 64))
self.assertAllClose(computed_values, (0.10141289, 0.0, 1483393.0))
def test_bottleneck_v2_width_32_channels_64_batch_size_32(self):
"""Test of a single ResNet block."""
computed_size, computed_values = self.resnet_block_run(
1716369119, batch_size=32, bottleneck=True, projection=False,
version=2, width=32, channels=64)
self.assertAllEqual(computed_size, (32, 32, 32, 64))
self.assertAllClose(computed_values, (1.4106897, 0.7455499, 834762.75))
def test_building_v1_width_32_channels_64_batch_size_32_with_proj(self):
"""Test of a single ResNet block."""
computed_size, computed_values = self.resnet_block_run(
1455996458, batch_size=32, bottleneck=False, projection=True,
version=1, width=32, channels=64)
self.assertAllEqual(computed_size, (32, 16, 16, 128))
self.assertAllClose(computed_values, (0.0, 0.0, 591701.3))
def test_building_v2_width_32_channels_64_batch_size_32_with_proj(self):
"""Test of a single ResNet block."""
computed_size, computed_values = self.resnet_block_run(
2770738568, batch_size=32, bottleneck=False, projection=True,
version=2, width=32, channels=64)
self.assertAllEqual(computed_size, (32, 16, 16, 128))
self.assertAllClose(computed_values, (-0.1908517, 0.2792631, -45776.055))
def test_building_v1_width_32_channels_64_batch_size_32(self):
"""Test of a single ResNet block."""
computed_size, computed_values = self.resnet_block_run(
1262621774, batch_size=32, bottleneck=False, projection=False,
version=1, width=32, channels=64)
self.assertAllEqual(computed_size, (32, 32, 32, 64))
self.assertAllClose(computed_values, (0.0, 0.0, 1493558.9))
def test_building_v2_width_32_channels_64_batch_size_32(self):
"""Test of a single ResNet block."""
computed_size, computed_values = self.resnet_block_run(
3856195393, batch_size=32, bottleneck=False, projection=False,
version=2, width=32, channels=64)
self.assertAllEqual(computed_size, (32, 32, 32, 64))
self.assertAllClose(computed_values, (-0.12920928, 0.38566422, 1157867.9))
if __name__ == "__main__":
tf.test.main()