Unverified commit 565c3fa3, authored by Neal Wu and committed by GitHub
Browse files

Merge pull request #3343 from tensorflow/resnet-num-classes

Allow users to pass in num_classes to ResNet
parents 7cb653fd 75c04257
...@@ -129,8 +129,16 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -129,8 +129,16 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
# Running the model # Running the model
############################################################################### ###############################################################################
class Cifar10Model(resnet.Model): class Cifar10Model(resnet.Model):
def __init__(self, resnet_size, data_format=None):
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES):
"""These are the parameters that work for CIFAR-10 data. """These are the parameters that work for CIFAR-10 data.
Args:
resnet_size: The number of convolutional layers needed in the model.
data_format: Either 'channels_first' or 'channels_last', specifying which
data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets.
""" """
if resnet_size % 6 != 2: if resnet_size % 6 != 2:
raise ValueError('resnet_size must be 6n + 2:', resnet_size) raise ValueError('resnet_size must be 6n + 2:', resnet_size)
...@@ -139,7 +147,7 @@ class Cifar10Model(resnet.Model): ...@@ -139,7 +147,7 @@ class Cifar10Model(resnet.Model):
super(Cifar10Model, self).__init__( super(Cifar10Model, self).__init__(
resnet_size=resnet_size, resnet_size=resnet_size,
num_classes=_NUM_CLASSES, num_classes=num_classes,
num_filters=16, num_filters=16,
kernel_size=3, kernel_size=3,
conv_stride=1, conv_stride=1,
......
...@@ -27,6 +27,9 @@ import cifar10_main ...@@ -27,6 +27,9 @@ import cifar10_main
tf.logging.set_verbosity(tf.logging.ERROR) tf.logging.set_verbosity(tf.logging.ERROR)
_BATCH_SIZE = 128 _BATCH_SIZE = 128
_HEIGHT = 32
_WIDTH = 32
_NUM_CHANNELS = 3
class BaseTest(tf.test.TestCase): class BaseTest(tf.test.TestCase):
...@@ -34,8 +37,8 @@ class BaseTest(tf.test.TestCase): ...@@ -34,8 +37,8 @@ class BaseTest(tf.test.TestCase):
def test_dataset_input_fn(self): def test_dataset_input_fn(self):
fake_data = bytearray() fake_data = bytearray()
fake_data.append(7) fake_data.append(7)
for i in range(3): for i in range(_NUM_CHANNELS):
for _ in range(1024): for _ in range(_HEIGHT * _WIDTH):
fake_data.append(i) fake_data.append(i)
_, filename = mkstemp(dir=self.get_temp_dir()) _, filename = mkstemp(dir=self.get_temp_dir())
...@@ -49,8 +52,8 @@ class BaseTest(tf.test.TestCase): ...@@ -49,8 +52,8 @@ class BaseTest(tf.test.TestCase):
lambda val: cifar10_main.parse_record(val, False)) lambda val: cifar10_main.parse_record(val, False))
image, label = fake_dataset.make_one_shot_iterator().get_next() image, label = fake_dataset.make_one_shot_iterator().get_next()
self.assertEqual(label.get_shape().as_list(), [10]) self.assertAllEqual(label.shape, (10,))
self.assertEqual(image.get_shape().as_list(), [32, 32, 3]) self.assertAllEqual(image.shape, (_HEIGHT, _WIDTH, _NUM_CHANNELS))
with self.test_session() as sess: with self.test_session() as sess:
image, label = sess.run([image, label]) image, label = sess.run([image, label])
...@@ -62,7 +65,7 @@ class BaseTest(tf.test.TestCase): ...@@ -62,7 +65,7 @@ class BaseTest(tf.test.TestCase):
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3) self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
def input_fn(self): def input_fn(self):
features = tf.random_uniform([_BATCH_SIZE, 32, 32, 3]) features = tf.random_uniform([_BATCH_SIZE, _HEIGHT, _WIDTH, _NUM_CHANNELS])
labels = tf.random_uniform( labels = tf.random_uniform(
[_BATCH_SIZE], maxval=9, dtype=tf.int32) [_BATCH_SIZE], maxval=9, dtype=tf.int32)
return features, tf.one_hot(labels, 10) return features, tf.one_hot(labels, 10)
...@@ -104,6 +107,17 @@ class BaseTest(tf.test.TestCase): ...@@ -104,6 +107,17 @@ class BaseTest(tf.test.TestCase):
def test_cifar10_model_fn_predict_mode(self): def test_cifar10_model_fn_predict_mode(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT)
def test_cifar10model_shape(self):
# Verifies that Cifar10Model honors a caller-supplied num_classes: the
# output of a forward pass on a batch must have shape
# (batch_size, num_classes).
batch_size = 135
num_classes = 246
# num_classes is passed explicitly instead of relying on the constructor's
# default (_NUM_CLASSES); 32 is the resnet_size argument.
model = cifar10_main.Cifar10Model(
32, data_format='channels_last', num_classes=num_classes)
# Random input laid out as [batch, height, width, channels], matching the
# 'channels_last' data_format requested above.
fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
# The model is invoked in training mode; only the static output shape is
# checked, so no session run is performed here.
output = model(fake_input, training=True)
self.assertAllEqual(output.shape, (batch_size, num_classes))
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -132,8 +132,16 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -132,8 +132,16 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
# Running the model # Running the model
############################################################################### ###############################################################################
class ImagenetModel(resnet.Model): class ImagenetModel(resnet.Model):
def __init__(self, resnet_size, data_format=None):
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES):
"""These are the parameters that work for Imagenet data. """These are the parameters that work for Imagenet data.
Args:
resnet_size: The number of convolutional layers needed in the model.
data_format: Either 'channels_first' or 'channels_last', specifying which
data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets.
""" """
# For bigger models, we want to use "bottleneck" layers # For bigger models, we want to use "bottleneck" layers
...@@ -146,7 +154,7 @@ class ImagenetModel(resnet.Model): ...@@ -146,7 +154,7 @@ class ImagenetModel(resnet.Model):
super(ImagenetModel, self).__init__( super(ImagenetModel, self).__init__(
resnet_size=resnet_size, resnet_size=resnet_size,
num_classes=_NUM_CLASSES, num_classes=num_classes,
num_filters=64, num_filters=64,
kernel_size=7, kernel_size=7,
conv_stride=2, conv_stride=2,
......
...@@ -176,6 +176,17 @@ class BaseTest(tf.test.TestCase): ...@@ -176,6 +176,17 @@ class BaseTest(tf.test.TestCase):
def test_resnet_model_fn_predict_mode(self): def test_resnet_model_fn_predict_mode(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT) self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT)
def test_imagenetmodel_shape(self):
# Verifies that ImagenetModel honors a caller-supplied num_classes: the
# output of a forward pass on a batch must have shape
# (batch_size, num_classes).
batch_size = 135
num_classes = 246
# num_classes is passed explicitly instead of relying on the constructor's
# default (_NUM_CLASSES); 50 is the resnet_size argument.
model = imagenet_main.ImagenetModel(
50, data_format='channels_last', num_classes=num_classes)
# Random input laid out as [batch, 224, 224, 3], i.e. channels in the last
# dimension to match the 'channels_last' data_format requested above.
fake_input = tf.random_uniform([batch_size, 224, 224, 3])
# The model is invoked in training mode; only the static output shape is
# checked, so no session run is performed here.
output = model(fake_input, training=True)
self.assertAllEqual(output.shape, (batch_size, num_classes))
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment