Commit c20eb5ff authored by Neal Wu's avatar Neal Wu
Browse files

Add unit tests and docstrings

parent 4d053deb
......@@ -131,7 +131,15 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
class Cifar10Model(resnet.Model):
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES):
"""These are the parameters that work for CIFAR-10 data."""
"""These are the parameters that work for CIFAR-10 data.
Args:
resnet_size: The number of convolutional layers needed in the model.
data_format: Either 'channels_first' or 'channels_last', specifying which
data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This
enables users to extend the same model to different datasets.
"""
if resnet_size % 6 != 2:
raise ValueError('resnet_size must be 6n + 2:', resnet_size)
......
......@@ -27,6 +27,9 @@ import cifar10_main
tf.logging.set_verbosity(tf.logging.ERROR)
_BATCH_SIZE = 128
_HEIGHT = 32
_WIDTH = 32
_NUM_CHANNELS = 3
class BaseTest(tf.test.TestCase):
......@@ -34,8 +37,8 @@ class BaseTest(tf.test.TestCase):
def test_dataset_input_fn(self):
fake_data = bytearray()
fake_data.append(7)
for i in range(3):
for _ in range(1024):
for i in range(_NUM_CHANNELS):
for _ in range(_HEIGHT * _WIDTH):
fake_data.append(i)
_, filename = mkstemp(dir=self.get_temp_dir())
......@@ -49,8 +52,8 @@ class BaseTest(tf.test.TestCase):
lambda val: cifar10_main.parse_record(val, False))
image, label = fake_dataset.make_one_shot_iterator().get_next()
self.assertEqual(label.get_shape().as_list(), [10])
self.assertEqual(image.get_shape().as_list(), [32, 32, 3])
self.assertAllEqual(label.shape, (10,))
self.assertAllEqual(image.shape, (_HEIGHT, _WIDTH, _NUM_CHANNELS))
with self.test_session() as sess:
image, label = sess.run([image, label])
......@@ -62,7 +65,7 @@ class BaseTest(tf.test.TestCase):
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
def input_fn(self):
features = tf.random_uniform([_BATCH_SIZE, 32, 32, 3])
features = tf.random_uniform([_BATCH_SIZE, _HEIGHT, _WIDTH, _NUM_CHANNELS])
labels = tf.random_uniform(
[_BATCH_SIZE], maxval=9, dtype=tf.int32)
return features, tf.one_hot(labels, 10)
......@@ -104,6 +107,18 @@ class BaseTest(tf.test.TestCase):
def test_cifar10_model_fn_predict_mode(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT)
def test_cifar10model_shape(self):
batch_size = 135
num_classes = 246
model = cifar10_main.Cifar10Model(
32, data_format='channels_last', num_classes=num_classes)
fake_input = tf.constant(
0.0, shape=[batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
output = model(fake_input, training=True)
self.assertAllEqual(output.shape, (batch_size, num_classes))
if __name__ == '__main__':
tf.test.main()
......@@ -134,7 +134,15 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
class ImagenetModel(resnet.Model):
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES):
"""These are the parameters that work for Imagenet data."""
"""These are the parameters that work for Imagenet data.
Args:
resnet_size: The number of convolutional layers needed in the model.
data_format: Either 'channels_first' or 'channels_last', specifying which
data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This
enables users to extend the same model to different datasets.
"""
# For bigger models, we want to use "bottleneck" layers
if resnet_size < 50:
......
......@@ -176,6 +176,18 @@ class BaseTest(tf.test.TestCase):
def test_resnet_model_fn_predict_mode(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT)
def test_imagenetmodel_shape(self):
batch_size = 135
num_classes = 246
model = imagenet_main.ImagenetModel(
50, data_format='channels_last', num_classes=num_classes)
fake_input = tf.constant(
0.0, shape=[batch_size, 224, 224, 3])
output = model(fake_input, training=True)
self.assertAllEqual(output.shape, (batch_size, num_classes))
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment