"configs/vscode:/vscode.git/clone" did not exist on "f7356f4baf1393d1c73dfbdd05944b925247b85e"
Unverified Commit 5be3c064 authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

Rename --version or --resnet_version (#4165)

* rename --version flag and fix tests to correctly specify version rather than verbosity

* rename version to resnet_version throughout

* fix bugs

* delint

* missed layer_test

* fix indent
parent eb0c0dfd
......@@ -140,7 +140,7 @@ class Cifar10Model(resnet_model.Model):
"""Model class with appropriate defaults for CIFAR-10 data."""
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
version=resnet_model.DEFAULT_VERSION,
resnet_version=resnet_model.DEFAULT_VERSION,
dtype=resnet_model.DEFAULT_DTYPE):
"""These are the parameters that work for CIFAR-10 data.
......@@ -150,8 +150,8 @@ class Cifar10Model(resnet_model.Model):
data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets.
version: Integer representing which version of the ResNet network to use.
See README for details. Valid values: [1, 2]
resnet_version: Integer representing which version of the ResNet network
to use. See README for details. Valid values: [1, 2]
dtype: The TensorFlow dtype to use for calculations.
Raises:
......@@ -174,7 +174,7 @@ class Cifar10Model(resnet_model.Model):
block_sizes=[num_blocks] * 3,
block_strides=[1, 2, 2],
final_size=64,
version=version,
resnet_version=resnet_version,
data_format=data_format,
dtype=dtype
)
......@@ -211,7 +211,7 @@ def cifar10_model_fn(features, labels, mode, params):
learning_rate_fn=learning_rate_fn,
momentum=0.9,
data_format=params['data_format'],
version=params['version'],
resnet_version=params['resnet_version'],
loss_scale=params['loss_scale'],
loss_filter_fn=loss_filter_fn,
dtype=params['dtype']
......
......@@ -76,7 +76,7 @@ class BaseTest(tf.test.TestCase):
for pixel in row:
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
def cifar10_model_fn_helper(self, mode, version, dtype):
def cifar10_model_fn_helper(self, mode, resnet_version, dtype):
input_fn = cifar10_main.get_synth_input_fn()
dataset = input_fn(True, '', _BATCH_SIZE)
iterator = dataset.make_one_shot_iterator()
......@@ -87,7 +87,7 @@ class BaseTest(tf.test.TestCase):
'resnet_size': 32,
'data_format': 'channels_last',
'batch_size': _BATCH_SIZE,
'version': version,
'resnet_version': resnet_version,
'loss_scale': 128 if dtype == tf.float16 else 1,
})
......@@ -111,56 +111,57 @@ class BaseTest(tf.test.TestCase):
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
def test_cifar10_model_fn_train_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=1,
dtype=tf.float32)
def test_cifar10_model_fn_trainmode__v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=2,
dtype=tf.float32)
def test_cifar10_model_fn_eval_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1,
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=1,
dtype=tf.float32)
def test_cifar10_model_fn_eval_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2,
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=2,
dtype=tf.float32)
def test_cifar10_model_fn_predict_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1,
dtype=tf.float32)
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT,
resnet_version=1, dtype=tf.float32)
def test_cifar10_model_fn_predict_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2,
dtype=tf.float32)
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT,
resnet_version=2, dtype=tf.float32)
def _test_cifar10model_shape(self, version):
def _test_cifar10model_shape(self, resnet_version):
batch_size = 135
num_classes = 246
model = cifar10_main.Cifar10Model(32, data_format='channels_last',
num_classes=num_classes, version=version)
num_classes=num_classes,
resnet_version=resnet_version)
fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
output = model(fake_input, training=True)
self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_cifar10model_shape_v1(self):
self._test_cifar10model_shape(version=1)
self._test_cifar10model_shape(resnet_version=1)
def test_cifar10model_shape_v2(self):
self._test_cifar10model_shape(version=2)
self._test_cifar10model_shape(resnet_version=2)
def test_cifar10_end_to_end_synthetic_v1(self):
integration.run_synthetic(
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1']
extra_flags=['-resnet_version', '1']
)
def test_cifar10_end_to_end_synthetic_v2(self):
integration.run_synthetic(
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2']
extra_flags=['-resnet_version', '2']
)
......
......@@ -197,7 +197,7 @@ class ImagenetModel(resnet_model.Model):
"""Model class with appropriate defaults for Imagenet data."""
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
version=resnet_model.DEFAULT_VERSION,
resnet_version=resnet_model.DEFAULT_VERSION,
dtype=resnet_model.DEFAULT_DTYPE):
"""These are the parameters that work for Imagenet data.
......@@ -207,8 +207,8 @@ class ImagenetModel(resnet_model.Model):
data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets.
version: Integer representing which version of the ResNet network to use.
See README for details. Valid values: [1, 2]
resnet_version: Integer representing which version of the ResNet network
to use. See README for details. Valid values: [1, 2]
dtype: The TensorFlow dtype to use for calculations.
"""
......@@ -232,7 +232,7 @@ class ImagenetModel(resnet_model.Model):
block_sizes=_get_block_sizes(resnet_size),
block_strides=[1, 2, 2, 2],
final_size=final_size,
version=version,
resnet_version=resnet_version,
data_format=data_format,
dtype=dtype
)
......@@ -289,7 +289,7 @@ def imagenet_model_fn(features, labels, mode, params):
learning_rate_fn=learning_rate_fn,
momentum=0.9,
data_format=params['data_format'],
version=params['version'],
resnet_version=params['resnet_version'],
loss_scale=params['loss_scale'],
loss_filter_fn=None,
dtype=params['dtype']
......
......@@ -41,7 +41,7 @@ class BaseTest(tf.test.TestCase):
super(BaseTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir())
def _tensor_shapes_helper(self, resnet_size, version, dtype, with_gpu):
def _tensor_shapes_helper(self, resnet_size, resnet_version, dtype, with_gpu):
"""Checks the tensor shapes after each phase of the ResNet model."""
def reshape(shape):
"""Returns the expected dimensions depending on if a GPU is being used."""
......@@ -59,7 +59,7 @@ class BaseTest(tf.test.TestCase):
model = imagenet_main.ImagenetModel(
resnet_size=resnet_size,
data_format='channels_first' if with_gpu else 'channels_last',
version=version,
resnet_version=resnet_version,
dtype=dtype
)
inputs = tf.random_uniform([1, 224, 224, 3])
......@@ -95,97 +95,99 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(dense.shape, (1, _LABEL_CLASSES))
self.assertAllEqual(output.shape, (1, _LABEL_CLASSES))
def tensor_shapes_helper(self, resnet_size, version, with_gpu=False):
self._tensor_shapes_helper(resnet_size=resnet_size, version=version,
def tensor_shapes_helper(self, resnet_size, resnet_version, with_gpu=False):
self._tensor_shapes_helper(resnet_size=resnet_size,
resnet_version=resnet_version,
dtype=tf.float32, with_gpu=with_gpu)
self._tensor_shapes_helper(resnet_size=resnet_size, version=version,
self._tensor_shapes_helper(resnet_size=resnet_size,
resnet_version=resnet_version,
dtype=tf.float16, with_gpu=with_gpu)
def test_tensor_shapes_resnet_18_v1(self):
self.tensor_shapes_helper(18, version=1)
self.tensor_shapes_helper(18, resnet_version=1)
def test_tensor_shapes_resnet_18_v2(self):
self.tensor_shapes_helper(18, version=2)
self.tensor_shapes_helper(18, resnet_version=2)
def test_tensor_shapes_resnet_34_v1(self):
self.tensor_shapes_helper(34, version=1)
self.tensor_shapes_helper(34, resnet_version=1)
def test_tensor_shapes_resnet_34_v2(self):
self.tensor_shapes_helper(34, version=2)
self.tensor_shapes_helper(34, resnet_version=2)
def test_tensor_shapes_resnet_50_v1(self):
self.tensor_shapes_helper(50, version=1)
self.tensor_shapes_helper(50, resnet_version=1)
def test_tensor_shapes_resnet_50_v2(self):
self.tensor_shapes_helper(50, version=2)
self.tensor_shapes_helper(50, resnet_version=2)
def test_tensor_shapes_resnet_101_v1(self):
self.tensor_shapes_helper(101, version=1)
self.tensor_shapes_helper(101, resnet_version=1)
def test_tensor_shapes_resnet_101_v2(self):
self.tensor_shapes_helper(101, version=2)
self.tensor_shapes_helper(101, resnet_version=2)
def test_tensor_shapes_resnet_152_v1(self):
self.tensor_shapes_helper(152, version=1)
self.tensor_shapes_helper(152, resnet_version=1)
def test_tensor_shapes_resnet_152_v2(self):
self.tensor_shapes_helper(152, version=2)
self.tensor_shapes_helper(152, resnet_version=2)
def test_tensor_shapes_resnet_200_v1(self):
self.tensor_shapes_helper(200, version=1)
self.tensor_shapes_helper(200, resnet_version=1)
def test_tensor_shapes_resnet_200_v2(self):
self.tensor_shapes_helper(200, version=2)
self.tensor_shapes_helper(200, resnet_version=2)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_18_with_gpu_v1(self):
self.tensor_shapes_helper(18, version=1, with_gpu=True)
self.tensor_shapes_helper(18, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_18_with_gpu_v2(self):
self.tensor_shapes_helper(18, version=2, with_gpu=True)
self.tensor_shapes_helper(18, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_34_with_gpu_v1(self):
self.tensor_shapes_helper(34, version=1, with_gpu=True)
self.tensor_shapes_helper(34, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_34_with_gpu_v2(self):
self.tensor_shapes_helper(34, version=2, with_gpu=True)
self.tensor_shapes_helper(34, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_50_with_gpu_v1(self):
self.tensor_shapes_helper(50, version=1, with_gpu=True)
self.tensor_shapes_helper(50, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_50_with_gpu_v2(self):
self.tensor_shapes_helper(50, version=2, with_gpu=True)
self.tensor_shapes_helper(50, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_101_with_gpu_v1(self):
self.tensor_shapes_helper(101, version=1, with_gpu=True)
self.tensor_shapes_helper(101, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_101_with_gpu_v2(self):
self.tensor_shapes_helper(101, version=2, with_gpu=True)
self.tensor_shapes_helper(101, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_152_with_gpu_v1(self):
self.tensor_shapes_helper(152, version=1, with_gpu=True)
self.tensor_shapes_helper(152, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_152_with_gpu_v2(self):
self.tensor_shapes_helper(152, version=2, with_gpu=True)
self.tensor_shapes_helper(152, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_200_with_gpu_v1(self):
self.tensor_shapes_helper(200, version=1, with_gpu=True)
self.tensor_shapes_helper(200, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_200_with_gpu_v2(self):
self.tensor_shapes_helper(200, version=2, with_gpu=True)
self.tensor_shapes_helper(200, resnet_version=2, with_gpu=True)
def resnet_model_fn_helper(self, mode, version, dtype):
def resnet_model_fn_helper(self, mode, resnet_version, dtype):
"""Tests that the EstimatorSpec is given the appropriate arguments."""
tf.train.create_global_step()
......@@ -199,7 +201,7 @@ class BaseTest(tf.test.TestCase):
'resnet_size': 50,
'data_format': 'channels_last',
'batch_size': _BATCH_SIZE,
'version': version,
'resnet_version': resnet_version,
'loss_scale': 128 if dtype == tf.float16 else 1,
})
......@@ -223,36 +225,36 @@ class BaseTest(tf.test.TestCase):
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
def test_resnet_model_fn_train_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=1,
dtype=tf.float32)
def test_resnet_model_fn_train_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=2,
dtype=tf.float32)
def test_resnet_model_fn_eval_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1,
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=1,
dtype=tf.float32)
def test_resnet_model_fn_eval_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2,
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=2,
dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1,
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, resnet_version=1,
dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2,
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, resnet_version=2,
dtype=tf.float32)
def _test_imagenetmodel_shape(self, version):
def _test_imagenetmodel_shape(self, resnet_version):
batch_size = 135
num_classes = 246
model = imagenet_main.ImagenetModel(
50, data_format='channels_last', num_classes=num_classes,
version=version)
resnet_version=resnet_version)
fake_input = tf.random_uniform([batch_size, 224, 224, 3])
output = model(fake_input, training=True)
......@@ -260,10 +262,10 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_imagenetmodel_shape_v1(self):
self._test_imagenetmodel_shape(version=1)
self._test_imagenetmodel_shape(resnet_version=1)
def test_imagenetmodel_shape_v2(self):
self._test_imagenetmodel_shape(version=2)
self._test_imagenetmodel_shape(resnet_version=2)
def test_imagenet_end_to_end_synthetic_v1(self):
integration.run_synthetic(
......@@ -280,25 +282,25 @@ class BaseTest(tf.test.TestCase):
def test_imagenet_end_to_end_synthetic_v1_tiny(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1', '-rs', '18']
extra_flags=['-resnet_version', '1', '-resnet_size', '18']
)
def test_imagenet_end_to_end_synthetic_v2_tiny(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2', '-rs', '18']
extra_flags=['-resnet_version', '2', '-resnet_size', '18']
)
def test_imagenet_end_to_end_synthetic_v1_huge(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1', '-rs', '200']
extra_flags=['-resnet_version', '1', '-resnet_size', '200']
)
def test_imagenet_end_to_end_synthetic_v2_huge(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2', '-rs', '200']
extra_flags=['-resnet_version', '2', '-resnet_size', '200']
)
......
......@@ -41,14 +41,22 @@ from official.utils.testing import reference_data
DATA_FORMAT = "channels_last" # CPU instructions often preclude channels_first
BATCH_SIZE = 32
BLOCK_TESTS = [
dict(bottleneck=True, projection=True, version=1, width=8, channels=4),
dict(bottleneck=True, projection=True, version=2, width=8, channels=4),
dict(bottleneck=True, projection=False, version=1, width=8, channels=4),
dict(bottleneck=True, projection=False, version=2, width=8, channels=4),
dict(bottleneck=False, projection=True, version=1, width=8, channels=4),
dict(bottleneck=False, projection=True, version=2, width=8, channels=4),
dict(bottleneck=False, projection=False, version=1, width=8, channels=4),
dict(bottleneck=False, projection=False, version=2, width=8, channels=4),
dict(bottleneck=True, projection=True, resnet_version=1, width=8,
channels=4),
dict(bottleneck=True, projection=True, resnet_version=2, width=8,
channels=4),
dict(bottleneck=True, projection=False, resnet_version=1, width=8,
channels=4),
dict(bottleneck=True, projection=False, resnet_version=2, width=8,
channels=4),
dict(bottleneck=False, projection=True, resnet_version=1, width=8,
channels=4),
dict(bottleneck=False, projection=True, resnet_version=2, width=8,
channels=4),
dict(bottleneck=False, projection=False, resnet_version=1, width=8,
channels=4),
dict(bottleneck=False, projection=False, resnet_version=2, width=8,
channels=4),
]
......@@ -95,7 +103,7 @@ class BaseTest(reference_data.BaseTest):
return projection_shortcut
def _resnet_block_ops(self, test, batch_size, bottleneck, projection,
version, width, channels):
resnet_version, width, channels):
"""Test whether resnet block construction has changed.
Args:
......@@ -104,7 +112,7 @@ class BaseTest(reference_data.BaseTest):
batch normalization.
bottleneck: Whether or not to use bottleneck layers.
projection: Whether or not to project the input.
version: Which version of ResNet to test.
resnet_version: Which version of ResNet to test.
width: The width of the fake image.
channels: The number of channels in the fake image.
"""
......@@ -113,12 +121,12 @@ class BaseTest(reference_data.BaseTest):
batch_size,
"bottleneck" if bottleneck else "building",
"_projection" if projection else "",
version,
resnet_version,
width,
channels
)
if version == 1:
if resnet_version == 1:
block_fn = resnet_model._building_block_v1
if bottleneck:
block_fn = resnet_model._bottleneck_block_v1
......
......@@ -354,7 +354,7 @@ class Model(object):
kernel_size,
conv_stride, first_pool_size, first_pool_stride,
block_sizes, block_strides,
final_size, version=DEFAULT_VERSION, data_format=None,
final_size, resnet_version=DEFAULT_VERSION, data_format=None,
dtype=DEFAULT_DTYPE):
"""Creates a model for classifying an image.
......@@ -377,8 +377,8 @@ class Model(object):
block_strides: List of integers representing the desired stride size for
each of the sets of block layers. Should be same length as block_sizes.
final_size: The expected size of the model after the second pooling.
version: Integer representing which version of the ResNet network to use.
See README for details. Valid values: [1, 2]
resnet_version: Integer representing which version of the ResNet network
to use. See README for details. Valid values: [1, 2]
data_format: Input format ('channels_last', 'channels_first', or None).
If set to None, the format is dependent on whether a GPU is available.
dtype: The TensorFlow dtype to use for calculations. If not specified
......@@ -393,19 +393,19 @@ class Model(object):
data_format = (
'channels_first' if tf.test.is_built_with_cuda() else 'channels_last')
self.resnet_version = version
if version not in (1, 2):
self.resnet_version = resnet_version
if resnet_version not in (1, 2):
raise ValueError(
'Resnet version should be 1 or 2. See README for citations.')
self.bottleneck = bottleneck
if bottleneck:
if version == 1:
if resnet_version == 1:
self.block_fn = _bottleneck_block_v1
else:
self.block_fn = _bottleneck_block_v2
else:
if version == 1:
if resnet_version == 1:
self.block_fn = _building_block_v1
else:
self.block_fn = _building_block_v2
......
......@@ -155,8 +155,8 @@ def learning_rate_with_decay(
def resnet_model_fn(features, labels, mode, model_class,
resnet_size, weight_decay, learning_rate_fn, momentum,
data_format, version, loss_scale, loss_filter_fn=None,
dtype=resnet_model.DEFAULT_DTYPE):
data_format, resnet_version, loss_scale,
loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE):
"""Shared functionality for different resnet model_fns.
Initializes the ResnetModel representing the model layers
......@@ -180,8 +180,8 @@ def resnet_model_fn(features, labels, mode, model_class,
momentum: momentum term used for optimization
data_format: Input format ('channels_last', 'channels_first', or None).
If set to None, the format is dependent on whether a GPU is available.
version: Integer representing which version of the ResNet network to use.
See README for details. Valid values: [1, 2]
resnet_version: Integer representing which version of the ResNet network to
use. See README for details. Valid values: [1, 2]
loss_scale: The factor to scale the loss for numerical stability. A detailed
summary is present in the arg parser help text.
loss_filter_fn: function that takes a string variable name and returns
......@@ -200,7 +200,8 @@ def resnet_model_fn(features, labels, mode, model_class,
features = tf.cast(features, dtype)
model = model_class(resnet_size, data_format, version=version, dtype=dtype)
model = model_class(resnet_size, data_format, resnet_version=resnet_version,
dtype=dtype)
logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)
......@@ -379,7 +380,7 @@ def resnet_main(
'resnet_size': int(flags_obj.resnet_size),
'data_format': flags_obj.data_format,
'batch_size': flags_obj.batch_size,
'version': int(flags_obj.version),
'resnet_version': int(flags_obj.resnet_version),
'loss_scale': flags_core.get_loss_scale(flags_obj),
'dtype': flags_core.get_tf_dtype(flags_obj)
})
......@@ -388,7 +389,7 @@ def resnet_main(
'batch_size': flags_obj.batch_size,
'dtype': flags_core.get_tf_dtype(flags_obj),
'resnet_size': flags_obj.resnet_size,
'resnet_version': flags_obj.version,
'resnet_version': flags_obj.resnet_version,
'synthetic_data': flags_obj.use_synthetic_data,
'train_epochs': flags_obj.train_epochs,
}
......@@ -456,7 +457,8 @@ def define_resnet_flags(resnet_size_choices=None):
flags.adopt_module_key_flags(flags_core)
flags.DEFINE_enum(
name='version', short_name='rv', default='2', enum_values=['1', '2'],
name='resnet_version', short_name='rv', default='2',
enum_values=['1', '2'],
help=flags_core.help_wrap(
'Version of ResNet. (1 or 2) See README.md for details.'))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment