Unverified commit 5be3c064, authored by Taylor Robie, committed by GitHub
Browse files

Rename --version to --resnet_version (#4165)

* rename --version flag and fix tests to correctly specify version rather than verbosity

* rename version to resnet_version throughout

* fix bugs

* delint

* missed layer_test

* fix indent
parent eb0c0dfd
...@@ -140,7 +140,7 @@ class Cifar10Model(resnet_model.Model): ...@@ -140,7 +140,7 @@ class Cifar10Model(resnet_model.Model):
"""Model class with appropriate defaults for CIFAR-10 data.""" """Model class with appropriate defaults for CIFAR-10 data."""
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
version=resnet_model.DEFAULT_VERSION, resnet_version=resnet_model.DEFAULT_VERSION,
dtype=resnet_model.DEFAULT_DTYPE): dtype=resnet_model.DEFAULT_DTYPE):
"""These are the parameters that work for CIFAR-10 data. """These are the parameters that work for CIFAR-10 data.
...@@ -150,8 +150,8 @@ class Cifar10Model(resnet_model.Model): ...@@ -150,8 +150,8 @@ class Cifar10Model(resnet_model.Model):
data format to use when setting up the model. data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets. enables users to extend the same model to their own datasets.
version: Integer representing which version of the ResNet network to use. resnet_version: Integer representing which version of the ResNet network
See README for details. Valid values: [1, 2] to use. See README for details. Valid values: [1, 2]
dtype: The TensorFlow dtype to use for calculations. dtype: The TensorFlow dtype to use for calculations.
Raises: Raises:
...@@ -174,7 +174,7 @@ class Cifar10Model(resnet_model.Model): ...@@ -174,7 +174,7 @@ class Cifar10Model(resnet_model.Model):
block_sizes=[num_blocks] * 3, block_sizes=[num_blocks] * 3,
block_strides=[1, 2, 2], block_strides=[1, 2, 2],
final_size=64, final_size=64,
version=version, resnet_version=resnet_version,
data_format=data_format, data_format=data_format,
dtype=dtype dtype=dtype
) )
...@@ -211,7 +211,7 @@ def cifar10_model_fn(features, labels, mode, params): ...@@ -211,7 +211,7 @@ def cifar10_model_fn(features, labels, mode, params):
learning_rate_fn=learning_rate_fn, learning_rate_fn=learning_rate_fn,
momentum=0.9, momentum=0.9,
data_format=params['data_format'], data_format=params['data_format'],
version=params['version'], resnet_version=params['resnet_version'],
loss_scale=params['loss_scale'], loss_scale=params['loss_scale'],
loss_filter_fn=loss_filter_fn, loss_filter_fn=loss_filter_fn,
dtype=params['dtype'] dtype=params['dtype']
......
...@@ -76,7 +76,7 @@ class BaseTest(tf.test.TestCase): ...@@ -76,7 +76,7 @@ class BaseTest(tf.test.TestCase):
for pixel in row: for pixel in row:
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3) self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
def cifar10_model_fn_helper(self, mode, version, dtype): def cifar10_model_fn_helper(self, mode, resnet_version, dtype):
input_fn = cifar10_main.get_synth_input_fn() input_fn = cifar10_main.get_synth_input_fn()
dataset = input_fn(True, '', _BATCH_SIZE) dataset = input_fn(True, '', _BATCH_SIZE)
iterator = dataset.make_one_shot_iterator() iterator = dataset.make_one_shot_iterator()
...@@ -87,7 +87,7 @@ class BaseTest(tf.test.TestCase): ...@@ -87,7 +87,7 @@ class BaseTest(tf.test.TestCase):
'resnet_size': 32, 'resnet_size': 32,
'data_format': 'channels_last', 'data_format': 'channels_last',
'batch_size': _BATCH_SIZE, 'batch_size': _BATCH_SIZE,
'version': version, 'resnet_version': resnet_version,
'loss_scale': 128 if dtype == tf.float16 else 1, 'loss_scale': 128 if dtype == tf.float16 else 1,
}) })
...@@ -111,56 +111,57 @@ class BaseTest(tf.test.TestCase): ...@@ -111,56 +111,57 @@ class BaseTest(tf.test.TestCase):
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32) self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
def test_cifar10_model_fn_train_mode_v1(self): def test_cifar10_model_fn_train_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=1,
dtype=tf.float32) dtype=tf.float32)
def test_cifar10_model_fn_trainmode__v2(self): def test_cifar10_model_fn_trainmode__v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=2,
dtype=tf.float32) dtype=tf.float32)
def test_cifar10_model_fn_eval_mode_v1(self): def test_cifar10_model_fn_eval_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=1,
dtype=tf.float32) dtype=tf.float32)
def test_cifar10_model_fn_eval_mode_v2(self): def test_cifar10_model_fn_eval_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=2,
dtype=tf.float32) dtype=tf.float32)
def test_cifar10_model_fn_predict_mode_v1(self): def test_cifar10_model_fn_predict_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT,
dtype=tf.float32) resnet_version=1, dtype=tf.float32)
def test_cifar10_model_fn_predict_mode_v2(self): def test_cifar10_model_fn_predict_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT,
dtype=tf.float32) resnet_version=2, dtype=tf.float32)
def _test_cifar10model_shape(self, version): def _test_cifar10model_shape(self, resnet_version):
batch_size = 135 batch_size = 135
num_classes = 246 num_classes = 246
model = cifar10_main.Cifar10Model(32, data_format='channels_last', model = cifar10_main.Cifar10Model(32, data_format='channels_last',
num_classes=num_classes, version=version) num_classes=num_classes,
resnet_version=resnet_version)
fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS]) fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
output = model(fake_input, training=True) output = model(fake_input, training=True)
self.assertAllEqual(output.shape, (batch_size, num_classes)) self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_cifar10model_shape_v1(self): def test_cifar10model_shape_v1(self):
self._test_cifar10model_shape(version=1) self._test_cifar10model_shape(resnet_version=1)
def test_cifar10model_shape_v2(self): def test_cifar10model_shape_v2(self):
self._test_cifar10model_shape(version=2) self._test_cifar10model_shape(resnet_version=2)
def test_cifar10_end_to_end_synthetic_v1(self): def test_cifar10_end_to_end_synthetic_v1(self):
integration.run_synthetic( integration.run_synthetic(
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(), main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1'] extra_flags=['-resnet_version', '1']
) )
def test_cifar10_end_to_end_synthetic_v2(self): def test_cifar10_end_to_end_synthetic_v2(self):
integration.run_synthetic( integration.run_synthetic(
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(), main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2'] extra_flags=['-resnet_version', '2']
) )
......
...@@ -197,7 +197,7 @@ class ImagenetModel(resnet_model.Model): ...@@ -197,7 +197,7 @@ class ImagenetModel(resnet_model.Model):
"""Model class with appropriate defaults for Imagenet data.""" """Model class with appropriate defaults for Imagenet data."""
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
version=resnet_model.DEFAULT_VERSION, resnet_version=resnet_model.DEFAULT_VERSION,
dtype=resnet_model.DEFAULT_DTYPE): dtype=resnet_model.DEFAULT_DTYPE):
"""These are the parameters that work for Imagenet data. """These are the parameters that work for Imagenet data.
...@@ -207,8 +207,8 @@ class ImagenetModel(resnet_model.Model): ...@@ -207,8 +207,8 @@ class ImagenetModel(resnet_model.Model):
data format to use when setting up the model. data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets. enables users to extend the same model to their own datasets.
version: Integer representing which version of the ResNet network to use. resnet_version: Integer representing which version of the ResNet network
See README for details. Valid values: [1, 2] to use. See README for details. Valid values: [1, 2]
dtype: The TensorFlow dtype to use for calculations. dtype: The TensorFlow dtype to use for calculations.
""" """
...@@ -232,7 +232,7 @@ class ImagenetModel(resnet_model.Model): ...@@ -232,7 +232,7 @@ class ImagenetModel(resnet_model.Model):
block_sizes=_get_block_sizes(resnet_size), block_sizes=_get_block_sizes(resnet_size),
block_strides=[1, 2, 2, 2], block_strides=[1, 2, 2, 2],
final_size=final_size, final_size=final_size,
version=version, resnet_version=resnet_version,
data_format=data_format, data_format=data_format,
dtype=dtype dtype=dtype
) )
...@@ -289,7 +289,7 @@ def imagenet_model_fn(features, labels, mode, params): ...@@ -289,7 +289,7 @@ def imagenet_model_fn(features, labels, mode, params):
learning_rate_fn=learning_rate_fn, learning_rate_fn=learning_rate_fn,
momentum=0.9, momentum=0.9,
data_format=params['data_format'], data_format=params['data_format'],
version=params['version'], resnet_version=params['resnet_version'],
loss_scale=params['loss_scale'], loss_scale=params['loss_scale'],
loss_filter_fn=None, loss_filter_fn=None,
dtype=params['dtype'] dtype=params['dtype']
......
...@@ -41,7 +41,7 @@ class BaseTest(tf.test.TestCase): ...@@ -41,7 +41,7 @@ class BaseTest(tf.test.TestCase):
super(BaseTest, self).tearDown() super(BaseTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir()) tf.gfile.DeleteRecursively(self.get_temp_dir())
def _tensor_shapes_helper(self, resnet_size, version, dtype, with_gpu): def _tensor_shapes_helper(self, resnet_size, resnet_version, dtype, with_gpu):
"""Checks the tensor shapes after each phase of the ResNet model.""" """Checks the tensor shapes after each phase of the ResNet model."""
def reshape(shape): def reshape(shape):
"""Returns the expected dimensions depending on if a GPU is being used.""" """Returns the expected dimensions depending on if a GPU is being used."""
...@@ -59,7 +59,7 @@ class BaseTest(tf.test.TestCase): ...@@ -59,7 +59,7 @@ class BaseTest(tf.test.TestCase):
model = imagenet_main.ImagenetModel( model = imagenet_main.ImagenetModel(
resnet_size=resnet_size, resnet_size=resnet_size,
data_format='channels_first' if with_gpu else 'channels_last', data_format='channels_first' if with_gpu else 'channels_last',
version=version, resnet_version=resnet_version,
dtype=dtype dtype=dtype
) )
inputs = tf.random_uniform([1, 224, 224, 3]) inputs = tf.random_uniform([1, 224, 224, 3])
...@@ -95,97 +95,99 @@ class BaseTest(tf.test.TestCase): ...@@ -95,97 +95,99 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(dense.shape, (1, _LABEL_CLASSES)) self.assertAllEqual(dense.shape, (1, _LABEL_CLASSES))
self.assertAllEqual(output.shape, (1, _LABEL_CLASSES)) self.assertAllEqual(output.shape, (1, _LABEL_CLASSES))
def tensor_shapes_helper(self, resnet_size, version, with_gpu=False): def tensor_shapes_helper(self, resnet_size, resnet_version, with_gpu=False):
self._tensor_shapes_helper(resnet_size=resnet_size, version=version, self._tensor_shapes_helper(resnet_size=resnet_size,
resnet_version=resnet_version,
dtype=tf.float32, with_gpu=with_gpu) dtype=tf.float32, with_gpu=with_gpu)
self._tensor_shapes_helper(resnet_size=resnet_size, version=version, self._tensor_shapes_helper(resnet_size=resnet_size,
resnet_version=resnet_version,
dtype=tf.float16, with_gpu=with_gpu) dtype=tf.float16, with_gpu=with_gpu)
def test_tensor_shapes_resnet_18_v1(self): def test_tensor_shapes_resnet_18_v1(self):
self.tensor_shapes_helper(18, version=1) self.tensor_shapes_helper(18, resnet_version=1)
def test_tensor_shapes_resnet_18_v2(self): def test_tensor_shapes_resnet_18_v2(self):
self.tensor_shapes_helper(18, version=2) self.tensor_shapes_helper(18, resnet_version=2)
def test_tensor_shapes_resnet_34_v1(self): def test_tensor_shapes_resnet_34_v1(self):
self.tensor_shapes_helper(34, version=1) self.tensor_shapes_helper(34, resnet_version=1)
def test_tensor_shapes_resnet_34_v2(self): def test_tensor_shapes_resnet_34_v2(self):
self.tensor_shapes_helper(34, version=2) self.tensor_shapes_helper(34, resnet_version=2)
def test_tensor_shapes_resnet_50_v1(self): def test_tensor_shapes_resnet_50_v1(self):
self.tensor_shapes_helper(50, version=1) self.tensor_shapes_helper(50, resnet_version=1)
def test_tensor_shapes_resnet_50_v2(self): def test_tensor_shapes_resnet_50_v2(self):
self.tensor_shapes_helper(50, version=2) self.tensor_shapes_helper(50, resnet_version=2)
def test_tensor_shapes_resnet_101_v1(self): def test_tensor_shapes_resnet_101_v1(self):
self.tensor_shapes_helper(101, version=1) self.tensor_shapes_helper(101, resnet_version=1)
def test_tensor_shapes_resnet_101_v2(self): def test_tensor_shapes_resnet_101_v2(self):
self.tensor_shapes_helper(101, version=2) self.tensor_shapes_helper(101, resnet_version=2)
def test_tensor_shapes_resnet_152_v1(self): def test_tensor_shapes_resnet_152_v1(self):
self.tensor_shapes_helper(152, version=1) self.tensor_shapes_helper(152, resnet_version=1)
def test_tensor_shapes_resnet_152_v2(self): def test_tensor_shapes_resnet_152_v2(self):
self.tensor_shapes_helper(152, version=2) self.tensor_shapes_helper(152, resnet_version=2)
def test_tensor_shapes_resnet_200_v1(self): def test_tensor_shapes_resnet_200_v1(self):
self.tensor_shapes_helper(200, version=1) self.tensor_shapes_helper(200, resnet_version=1)
def test_tensor_shapes_resnet_200_v2(self): def test_tensor_shapes_resnet_200_v2(self):
self.tensor_shapes_helper(200, version=2) self.tensor_shapes_helper(200, resnet_version=2)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_18_with_gpu_v1(self): def test_tensor_shapes_resnet_18_with_gpu_v1(self):
self.tensor_shapes_helper(18, version=1, with_gpu=True) self.tensor_shapes_helper(18, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_18_with_gpu_v2(self): def test_tensor_shapes_resnet_18_with_gpu_v2(self):
self.tensor_shapes_helper(18, version=2, with_gpu=True) self.tensor_shapes_helper(18, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_34_with_gpu_v1(self): def test_tensor_shapes_resnet_34_with_gpu_v1(self):
self.tensor_shapes_helper(34, version=1, with_gpu=True) self.tensor_shapes_helper(34, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_34_with_gpu_v2(self): def test_tensor_shapes_resnet_34_with_gpu_v2(self):
self.tensor_shapes_helper(34, version=2, with_gpu=True) self.tensor_shapes_helper(34, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_50_with_gpu_v1(self): def test_tensor_shapes_resnet_50_with_gpu_v1(self):
self.tensor_shapes_helper(50, version=1, with_gpu=True) self.tensor_shapes_helper(50, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_50_with_gpu_v2(self): def test_tensor_shapes_resnet_50_with_gpu_v2(self):
self.tensor_shapes_helper(50, version=2, with_gpu=True) self.tensor_shapes_helper(50, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_101_with_gpu_v1(self): def test_tensor_shapes_resnet_101_with_gpu_v1(self):
self.tensor_shapes_helper(101, version=1, with_gpu=True) self.tensor_shapes_helper(101, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_101_with_gpu_v2(self): def test_tensor_shapes_resnet_101_with_gpu_v2(self):
self.tensor_shapes_helper(101, version=2, with_gpu=True) self.tensor_shapes_helper(101, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_152_with_gpu_v1(self): def test_tensor_shapes_resnet_152_with_gpu_v1(self):
self.tensor_shapes_helper(152, version=1, with_gpu=True) self.tensor_shapes_helper(152, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_152_with_gpu_v2(self): def test_tensor_shapes_resnet_152_with_gpu_v2(self):
self.tensor_shapes_helper(152, version=2, with_gpu=True) self.tensor_shapes_helper(152, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_200_with_gpu_v1(self): def test_tensor_shapes_resnet_200_with_gpu_v1(self):
self.tensor_shapes_helper(200, version=1, with_gpu=True) self.tensor_shapes_helper(200, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_200_with_gpu_v2(self): def test_tensor_shapes_resnet_200_with_gpu_v2(self):
self.tensor_shapes_helper(200, version=2, with_gpu=True) self.tensor_shapes_helper(200, resnet_version=2, with_gpu=True)
def resnet_model_fn_helper(self, mode, version, dtype): def resnet_model_fn_helper(self, mode, resnet_version, dtype):
"""Tests that the EstimatorSpec is given the appropriate arguments.""" """Tests that the EstimatorSpec is given the appropriate arguments."""
tf.train.create_global_step() tf.train.create_global_step()
...@@ -199,7 +201,7 @@ class BaseTest(tf.test.TestCase): ...@@ -199,7 +201,7 @@ class BaseTest(tf.test.TestCase):
'resnet_size': 50, 'resnet_size': 50,
'data_format': 'channels_last', 'data_format': 'channels_last',
'batch_size': _BATCH_SIZE, 'batch_size': _BATCH_SIZE,
'version': version, 'resnet_version': resnet_version,
'loss_scale': 128 if dtype == tf.float16 else 1, 'loss_scale': 128 if dtype == tf.float16 else 1,
}) })
...@@ -223,36 +225,36 @@ class BaseTest(tf.test.TestCase): ...@@ -223,36 +225,36 @@ class BaseTest(tf.test.TestCase):
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32) self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
def test_resnet_model_fn_train_mode_v1(self): def test_resnet_model_fn_train_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1, self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=1,
dtype=tf.float32) dtype=tf.float32)
def test_resnet_model_fn_train_mode_v2(self): def test_resnet_model_fn_train_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2, self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=2,
dtype=tf.float32) dtype=tf.float32)
def test_resnet_model_fn_eval_mode_v1(self): def test_resnet_model_fn_eval_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1, self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=1,
dtype=tf.float32) dtype=tf.float32)
def test_resnet_model_fn_eval_mode_v2(self): def test_resnet_model_fn_eval_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2, self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=2,
dtype=tf.float32) dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v1(self): def test_resnet_model_fn_predict_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1, self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, resnet_version=1,
dtype=tf.float32) dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v2(self): def test_resnet_model_fn_predict_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2, self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, resnet_version=2,
dtype=tf.float32) dtype=tf.float32)
def _test_imagenetmodel_shape(self, version): def _test_imagenetmodel_shape(self, resnet_version):
batch_size = 135 batch_size = 135
num_classes = 246 num_classes = 246
model = imagenet_main.ImagenetModel( model = imagenet_main.ImagenetModel(
50, data_format='channels_last', num_classes=num_classes, 50, data_format='channels_last', num_classes=num_classes,
version=version) resnet_version=resnet_version)
fake_input = tf.random_uniform([batch_size, 224, 224, 3]) fake_input = tf.random_uniform([batch_size, 224, 224, 3])
output = model(fake_input, training=True) output = model(fake_input, training=True)
...@@ -260,10 +262,10 @@ class BaseTest(tf.test.TestCase): ...@@ -260,10 +262,10 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(output.shape, (batch_size, num_classes)) self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_imagenetmodel_shape_v1(self): def test_imagenetmodel_shape_v1(self):
self._test_imagenetmodel_shape(version=1) self._test_imagenetmodel_shape(resnet_version=1)
def test_imagenetmodel_shape_v2(self): def test_imagenetmodel_shape_v2(self):
self._test_imagenetmodel_shape(version=2) self._test_imagenetmodel_shape(resnet_version=2)
def test_imagenet_end_to_end_synthetic_v1(self): def test_imagenet_end_to_end_synthetic_v1(self):
integration.run_synthetic( integration.run_synthetic(
...@@ -280,25 +282,25 @@ class BaseTest(tf.test.TestCase): ...@@ -280,25 +282,25 @@ class BaseTest(tf.test.TestCase):
def test_imagenet_end_to_end_synthetic_v1_tiny(self): def test_imagenet_end_to_end_synthetic_v1_tiny(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1', '-rs', '18'] extra_flags=['-resnet_version', '1', '-resnet_size', '18']
) )
def test_imagenet_end_to_end_synthetic_v2_tiny(self): def test_imagenet_end_to_end_synthetic_v2_tiny(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2', '-rs', '18'] extra_flags=['-resnet_version', '2', '-resnet_size', '18']
) )
def test_imagenet_end_to_end_synthetic_v1_huge(self): def test_imagenet_end_to_end_synthetic_v1_huge(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1', '-rs', '200'] extra_flags=['-resnet_version', '1', '-resnet_size', '200']
) )
def test_imagenet_end_to_end_synthetic_v2_huge(self): def test_imagenet_end_to_end_synthetic_v2_huge(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2', '-rs', '200'] extra_flags=['-resnet_version', '2', '-resnet_size', '200']
) )
......
...@@ -41,14 +41,22 @@ from official.utils.testing import reference_data ...@@ -41,14 +41,22 @@ from official.utils.testing import reference_data
DATA_FORMAT = "channels_last" # CPU instructions often preclude channels_first DATA_FORMAT = "channels_last" # CPU instructions often preclude channels_first
BATCH_SIZE = 32 BATCH_SIZE = 32
BLOCK_TESTS = [ BLOCK_TESTS = [
dict(bottleneck=True, projection=True, version=1, width=8, channels=4), dict(bottleneck=True, projection=True, resnet_version=1, width=8,
dict(bottleneck=True, projection=True, version=2, width=8, channels=4), channels=4),
dict(bottleneck=True, projection=False, version=1, width=8, channels=4), dict(bottleneck=True, projection=True, resnet_version=2, width=8,
dict(bottleneck=True, projection=False, version=2, width=8, channels=4), channels=4),
dict(bottleneck=False, projection=True, version=1, width=8, channels=4), dict(bottleneck=True, projection=False, resnet_version=1, width=8,
dict(bottleneck=False, projection=True, version=2, width=8, channels=4), channels=4),
dict(bottleneck=False, projection=False, version=1, width=8, channels=4), dict(bottleneck=True, projection=False, resnet_version=2, width=8,
dict(bottleneck=False, projection=False, version=2, width=8, channels=4), channels=4),
dict(bottleneck=False, projection=True, resnet_version=1, width=8,
channels=4),
dict(bottleneck=False, projection=True, resnet_version=2, width=8,
channels=4),
dict(bottleneck=False, projection=False, resnet_version=1, width=8,
channels=4),
dict(bottleneck=False, projection=False, resnet_version=2, width=8,
channels=4),
] ]
...@@ -95,7 +103,7 @@ class BaseTest(reference_data.BaseTest): ...@@ -95,7 +103,7 @@ class BaseTest(reference_data.BaseTest):
return projection_shortcut return projection_shortcut
def _resnet_block_ops(self, test, batch_size, bottleneck, projection, def _resnet_block_ops(self, test, batch_size, bottleneck, projection,
version, width, channels): resnet_version, width, channels):
"""Test whether resnet block construction has changed. """Test whether resnet block construction has changed.
Args: Args:
...@@ -104,7 +112,7 @@ class BaseTest(reference_data.BaseTest): ...@@ -104,7 +112,7 @@ class BaseTest(reference_data.BaseTest):
batch normalization. batch normalization.
bottleneck: Whether or not to use bottleneck layers. bottleneck: Whether or not to use bottleneck layers.
projection: Whether or not to project the input. projection: Whether or not to project the input.
version: Which version of ResNet to test. resnet_version: Which version of ResNet to test.
width: The width of the fake image. width: The width of the fake image.
channels: The number of channels in the fake image. channels: The number of channels in the fake image.
""" """
...@@ -113,12 +121,12 @@ class BaseTest(reference_data.BaseTest): ...@@ -113,12 +121,12 @@ class BaseTest(reference_data.BaseTest):
batch_size, batch_size,
"bottleneck" if bottleneck else "building", "bottleneck" if bottleneck else "building",
"_projection" if projection else "", "_projection" if projection else "",
version, resnet_version,
width, width,
channels channels
) )
if version == 1: if resnet_version == 1:
block_fn = resnet_model._building_block_v1 block_fn = resnet_model._building_block_v1
if bottleneck: if bottleneck:
block_fn = resnet_model._bottleneck_block_v1 block_fn = resnet_model._bottleneck_block_v1
......
...@@ -354,7 +354,7 @@ class Model(object): ...@@ -354,7 +354,7 @@ class Model(object):
kernel_size, kernel_size,
conv_stride, first_pool_size, first_pool_stride, conv_stride, first_pool_size, first_pool_stride,
block_sizes, block_strides, block_sizes, block_strides,
final_size, version=DEFAULT_VERSION, data_format=None, final_size, resnet_version=DEFAULT_VERSION, data_format=None,
dtype=DEFAULT_DTYPE): dtype=DEFAULT_DTYPE):
"""Creates a model for classifying an image. """Creates a model for classifying an image.
...@@ -377,8 +377,8 @@ class Model(object): ...@@ -377,8 +377,8 @@ class Model(object):
block_strides: List of integers representing the desired stride size for block_strides: List of integers representing the desired stride size for
each of the sets of block layers. Should be same length as block_sizes. each of the sets of block layers. Should be same length as block_sizes.
final_size: The expected size of the model after the second pooling. final_size: The expected size of the model after the second pooling.
version: Integer representing which version of the ResNet network to use. resnet_version: Integer representing which version of the ResNet network
See README for details. Valid values: [1, 2] to use. See README for details. Valid values: [1, 2]
data_format: Input format ('channels_last', 'channels_first', or None). data_format: Input format ('channels_last', 'channels_first', or None).
If set to None, the format is dependent on whether a GPU is available. If set to None, the format is dependent on whether a GPU is available.
dtype: The TensorFlow dtype to use for calculations. If not specified dtype: The TensorFlow dtype to use for calculations. If not specified
...@@ -393,19 +393,19 @@ class Model(object): ...@@ -393,19 +393,19 @@ class Model(object):
data_format = ( data_format = (
'channels_first' if tf.test.is_built_with_cuda() else 'channels_last') 'channels_first' if tf.test.is_built_with_cuda() else 'channels_last')
self.resnet_version = version self.resnet_version = resnet_version
if version not in (1, 2): if resnet_version not in (1, 2):
raise ValueError( raise ValueError(
'Resnet version should be 1 or 2. See README for citations.') 'Resnet version should be 1 or 2. See README for citations.')
self.bottleneck = bottleneck self.bottleneck = bottleneck
if bottleneck: if bottleneck:
if version == 1: if resnet_version == 1:
self.block_fn = _bottleneck_block_v1 self.block_fn = _bottleneck_block_v1
else: else:
self.block_fn = _bottleneck_block_v2 self.block_fn = _bottleneck_block_v2
else: else:
if version == 1: if resnet_version == 1:
self.block_fn = _building_block_v1 self.block_fn = _building_block_v1
else: else:
self.block_fn = _building_block_v2 self.block_fn = _building_block_v2
......
...@@ -155,8 +155,8 @@ def learning_rate_with_decay( ...@@ -155,8 +155,8 @@ def learning_rate_with_decay(
def resnet_model_fn(features, labels, mode, model_class, def resnet_model_fn(features, labels, mode, model_class,
resnet_size, weight_decay, learning_rate_fn, momentum, resnet_size, weight_decay, learning_rate_fn, momentum,
data_format, version, loss_scale, loss_filter_fn=None, data_format, resnet_version, loss_scale,
dtype=resnet_model.DEFAULT_DTYPE): loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE):
"""Shared functionality for different resnet model_fns. """Shared functionality for different resnet model_fns.
Initializes the ResnetModel representing the model layers Initializes the ResnetModel representing the model layers
...@@ -180,8 +180,8 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -180,8 +180,8 @@ def resnet_model_fn(features, labels, mode, model_class,
momentum: momentum term used for optimization momentum: momentum term used for optimization
data_format: Input format ('channels_last', 'channels_first', or None). data_format: Input format ('channels_last', 'channels_first', or None).
If set to None, the format is dependent on whether a GPU is available. If set to None, the format is dependent on whether a GPU is available.
version: Integer representing which version of the ResNet network to use. resnet_version: Integer representing which version of the ResNet network to
See README for details. Valid values: [1, 2] use. See README for details. Valid values: [1, 2]
loss_scale: The factor to scale the loss for numerical stability. A detailed loss_scale: The factor to scale the loss for numerical stability. A detailed
summary is present in the arg parser help text. summary is present in the arg parser help text.
loss_filter_fn: function that takes a string variable name and returns loss_filter_fn: function that takes a string variable name and returns
...@@ -200,7 +200,8 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -200,7 +200,8 @@ def resnet_model_fn(features, labels, mode, model_class,
features = tf.cast(features, dtype) features = tf.cast(features, dtype)
model = model_class(resnet_size, data_format, version=version, dtype=dtype) model = model_class(resnet_size, data_format, resnet_version=resnet_version,
dtype=dtype)
logits = model(features, mode == tf.estimator.ModeKeys.TRAIN) logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)
...@@ -379,7 +380,7 @@ def resnet_main( ...@@ -379,7 +380,7 @@ def resnet_main(
'resnet_size': int(flags_obj.resnet_size), 'resnet_size': int(flags_obj.resnet_size),
'data_format': flags_obj.data_format, 'data_format': flags_obj.data_format,
'batch_size': flags_obj.batch_size, 'batch_size': flags_obj.batch_size,
'version': int(flags_obj.version), 'resnet_version': int(flags_obj.resnet_version),
'loss_scale': flags_core.get_loss_scale(flags_obj), 'loss_scale': flags_core.get_loss_scale(flags_obj),
'dtype': flags_core.get_tf_dtype(flags_obj) 'dtype': flags_core.get_tf_dtype(flags_obj)
}) })
...@@ -388,7 +389,7 @@ def resnet_main( ...@@ -388,7 +389,7 @@ def resnet_main(
'batch_size': flags_obj.batch_size, 'batch_size': flags_obj.batch_size,
'dtype': flags_core.get_tf_dtype(flags_obj), 'dtype': flags_core.get_tf_dtype(flags_obj),
'resnet_size': flags_obj.resnet_size, 'resnet_size': flags_obj.resnet_size,
'resnet_version': flags_obj.version, 'resnet_version': flags_obj.resnet_version,
'synthetic_data': flags_obj.use_synthetic_data, 'synthetic_data': flags_obj.use_synthetic_data,
'train_epochs': flags_obj.train_epochs, 'train_epochs': flags_obj.train_epochs,
} }
...@@ -456,7 +457,8 @@ def define_resnet_flags(resnet_size_choices=None): ...@@ -456,7 +457,8 @@ def define_resnet_flags(resnet_size_choices=None):
flags.adopt_module_key_flags(flags_core) flags.adopt_module_key_flags(flags_core)
flags.DEFINE_enum( flags.DEFINE_enum(
name='version', short_name='rv', default='2', enum_values=['1', '2'], name='resnet_version', short_name='rv', default='2',
enum_values=['1', '2'],
help=flags_core.help_wrap( help=flags_core.help_wrap(
'Version of ResNet. (1 or 2) See README.md for details.')) 'Version of ResNet. (1 or 2) See README.md for details.'))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment