Commit 487d18e2, authored by Neal Wu, committed by GitHub
Browse files

Add a data_format flag to the models + other fixes (#2583)

parent d86aa760
......@@ -35,8 +35,21 @@ parser.add_argument('--data_dir', type=str, default='/tmp/mnist_data',
parser.add_argument('--model_dir', type=str, default='/tmp/mnist_model',
help='The directory where the model will be stored.')
parser.add_argument('--steps', type=int, default=20000,
help='Number of steps to train.')
parser.add_argument('--train_epochs', type=int, default=40,
help='Number of epochs to train.')
parser.add_argument(
'--data_format', type=str, default=None,
choices=['channels_first', 'channels_last'],
help='A flag to override the data format used in the model. channels_first '
'provides a performance boost on GPU but is not always compatible '
'with CPU. If left unspecified, the data format will be chosen '
'automatically based on whether TensorFlow was built for CPU or GPU.')
_NUM_IMAGES = {
'train': 50000,
'validation': 10000,
}
def input_fn(mode, batch_size=1):
......@@ -89,13 +102,16 @@ def mnist_model(inputs, mode):
# Reshape X to 4-D tensor: [batch_size, width, height, channels]
# MNIST images are 28x28 pixels, and have one color channel
inputs = tf.reshape(inputs, [-1, 28, 28, 1])
data_format = 'channels_last'
data_format = FLAGS.data_format
if tf.test.is_built_with_cuda():
if data_format is None:
# When running on GPU, transpose the data from channels_last (NHWC) to
# channels_first (NCHW) to improve performance.
# See https://www.tensorflow.org/performance/performance_guide#data_formats
data_format = 'channels_first'
data_format = ('channels_first' if tf.test.is_built_with_cuda() else
'channels_last')
if data_format == 'channels_first':
inputs = tf.transpose(inputs, [0, 3, 1, 2])
# Convolutional Layer #1
......@@ -211,9 +227,11 @@ def main(unused_argv):
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)
batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
mnist_classifier.train(
input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, FLAGS.batch_size),
steps=FLAGS.steps,
steps=FLAGS.train_epochs * batches_per_epoch,
hooks=[logging_hook])
# Evaluate the model and print results
......
......@@ -65,4 +65,5 @@ class BaseTest(tf.test.TestCase):
if __name__ == '__main__':
mnist.FLAGS = mnist.parser.parse_args()
tf.test.main()
......@@ -42,11 +42,19 @@ parser.add_argument('--train_epochs', type=int, default=250,
help='The number of epochs to train.')
parser.add_argument('--epochs_per_eval', type=int, default=10,
help='The number of batches to run in between evaluations.')
help='The number of epochs to run in between evaluations.')
parser.add_argument('--batch_size', type=int, default=128,
help='The number of images per batch.')
parser.add_argument(
'--data_format', type=str, default=None,
choices=['channels_first', 'channels_last'],
help='A flag to override the data format used in the model. channels_first '
'provides a performance boost on GPU but is not always compatible '
'with CPU. If left unspecified, the data format will be chosen '
'automatically based on whether TensorFlow was built for CPU or GPU.')
_HEIGHT = 32
_WIDTH = 32
_DEPTH = 3
......@@ -172,7 +180,7 @@ def cifar10_model_fn(features, labels, mode):
tf.summary.image('images', features, max_outputs=6)
network = resnet_model.cifar10_resnet_v2_generator(
FLAGS.resnet_size, _NUM_CLASSES)
FLAGS.resnet_size, _NUM_CLASSES, FLAGS.data_format)
inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)
......
......@@ -53,6 +53,14 @@ parser.add_argument(
'--batch_size', type=int, default=32,
help='Batch size for training and evaluation.')
parser.add_argument(
'--data_format', type=str, default=None,
choices=['channels_first', 'channels_last'],
help='A flag to override the data format used in the model. channels_first '
'provides a performance boost on GPU but is not always compatible '
'with CPU. If left unspecified, the data format will be chosen '
'automatically based on whether TensorFlow was built for CPU or GPU.')
_DEFAULT_IMAGE_SIZE = 224
_NUM_CHANNELS = 3
_LABEL_CLASSES = 1001
......@@ -148,8 +156,8 @@ def resnet_model_fn(features, labels, mode):
"""Our model_fn for ResNet to be used with our Estimator."""
tf.summary.image('images', features, max_outputs=6)
network = resnet_model.resnet_v2(
resnet_size=FLAGS.resnet_size, num_classes=_LABEL_CLASSES)
network = resnet_model.imagenet_resnet_v2(
FLAGS.resnet_size, _LABEL_CLASSES, FLAGS.data_format)
logits = network(
inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
......
......@@ -45,7 +45,7 @@ class BaseTest(tf.test.TestCase):
with graph.as_default(), self.test_session(
use_gpu=with_gpu, force_gpu=with_gpu):
model = resnet_model.resnet_v2(
model = resnet_model.imagenet_resnet_v2(
resnet_size, 456,
data_format='channels_first' if with_gpu else 'channels_last')
inputs = tf.random_uniform([1, 224, 224, 3])
......
......@@ -346,7 +346,7 @@ def imagenet_resnet_v2_generator(block_fn, layers, num_classes,
return model
def resnet_v2(resnet_size, num_classes, data_format=None):
def imagenet_resnet_v2(resnet_size, num_classes, data_format=None):
"""Returns the ResNet model for a given size and number of output classes."""
model_params = {
18: {'block': building_block, 'layers': [2, 2, 2, 2]},
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment