Commit 487d18e2 authored by Neal Wu, committed by GitHub
Browse files

Add a data_format flag to the models + other fixes (#2583)

parent d86aa760
...@@ -35,8 +35,21 @@ parser.add_argument('--data_dir', type=str, default='/tmp/mnist_data', ...@@ -35,8 +35,21 @@ parser.add_argument('--data_dir', type=str, default='/tmp/mnist_data',
parser.add_argument('--model_dir', type=str, default='/tmp/mnist_model', parser.add_argument('--model_dir', type=str, default='/tmp/mnist_model',
help='The directory where the model will be stored.') help='The directory where the model will be stored.')
parser.add_argument('--steps', type=int, default=20000, parser.add_argument('--train_epochs', type=int, default=40,
help='Number of steps to train.') help='Number of epochs to train.')
parser.add_argument(
'--data_format', type=str, default=None,
choices=['channels_first', 'channels_last'],
help='A flag to override the data format used in the model. channels_first '
'provides a performance boost on GPU but is not always compatible '
'with CPU. If left unspecified, the data format will be chosen '
'automatically based on whether TensorFlow was built for CPU or GPU.')
_NUM_IMAGES = {
'train': 50000,
'validation': 10000,
}
def input_fn(mode, batch_size=1): def input_fn(mode, batch_size=1):
...@@ -89,13 +102,16 @@ def mnist_model(inputs, mode): ...@@ -89,13 +102,16 @@ def mnist_model(inputs, mode):
# Reshape X to 4-D tensor: [batch_size, width, height, channels] # Reshape X to 4-D tensor: [batch_size, width, height, channels]
# MNIST images are 28x28 pixels, and have one color channel # MNIST images are 28x28 pixels, and have one color channel
inputs = tf.reshape(inputs, [-1, 28, 28, 1]) inputs = tf.reshape(inputs, [-1, 28, 28, 1])
data_format = 'channels_last' data_format = FLAGS.data_format
if tf.test.is_built_with_cuda(): if data_format is None:
# When running on GPU, transpose the data from channels_last (NHWC) to # When running on GPU, transpose the data from channels_last (NHWC) to
# channels_first (NCHW) to improve performance. # channels_first (NCHW) to improve performance.
# See https://www.tensorflow.org/performance/performance_guide#data_formats # See https://www.tensorflow.org/performance/performance_guide#data_formats
data_format = 'channels_first' data_format = ('channels_first' if tf.test.is_built_with_cuda() else
'channels_last')
if data_format == 'channels_first':
inputs = tf.transpose(inputs, [0, 3, 1, 2]) inputs = tf.transpose(inputs, [0, 3, 1, 2])
# Convolutional Layer #1 # Convolutional Layer #1
...@@ -211,9 +227,11 @@ def main(unused_argv): ...@@ -211,9 +227,11 @@ def main(unused_argv):
logging_hook = tf.train.LoggingTensorHook( logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100) tensors=tensors_to_log, every_n_iter=100)
batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
mnist_classifier.train( mnist_classifier.train(
input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, FLAGS.batch_size), input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, FLAGS.batch_size),
steps=FLAGS.steps, steps=FLAGS.train_epochs * batches_per_epoch,
hooks=[logging_hook]) hooks=[logging_hook])
# Evaluate the model and print results # Evaluate the model and print results
......
...@@ -65,4 +65,5 @@ class BaseTest(tf.test.TestCase): ...@@ -65,4 +65,5 @@ class BaseTest(tf.test.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
mnist.FLAGS = mnist.parser.parse_args()
tf.test.main() tf.test.main()
...@@ -42,11 +42,19 @@ parser.add_argument('--train_epochs', type=int, default=250, ...@@ -42,11 +42,19 @@ parser.add_argument('--train_epochs', type=int, default=250,
help='The number of epochs to train.') help='The number of epochs to train.')
parser.add_argument('--epochs_per_eval', type=int, default=10, parser.add_argument('--epochs_per_eval', type=int, default=10,
help='The number of batches to run in between evaluations.') help='The number of epochs to run in between evaluations.')
parser.add_argument('--batch_size', type=int, default=128, parser.add_argument('--batch_size', type=int, default=128,
help='The number of images per batch.') help='The number of images per batch.')
parser.add_argument(
'--data_format', type=str, default=None,
choices=['channels_first', 'channels_last'],
help='A flag to override the data format used in the model. channels_first '
'provides a performance boost on GPU but is not always compatible '
'with CPU. If left unspecified, the data format will be chosen '
'automatically based on whether TensorFlow was built for CPU or GPU.')
_HEIGHT = 32 _HEIGHT = 32
_WIDTH = 32 _WIDTH = 32
_DEPTH = 3 _DEPTH = 3
...@@ -172,7 +180,7 @@ def cifar10_model_fn(features, labels, mode): ...@@ -172,7 +180,7 @@ def cifar10_model_fn(features, labels, mode):
tf.summary.image('images', features, max_outputs=6) tf.summary.image('images', features, max_outputs=6)
network = resnet_model.cifar10_resnet_v2_generator( network = resnet_model.cifar10_resnet_v2_generator(
FLAGS.resnet_size, _NUM_CLASSES) FLAGS.resnet_size, _NUM_CLASSES, FLAGS.data_format)
inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH]) inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN) logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)
......
...@@ -53,6 +53,14 @@ parser.add_argument( ...@@ -53,6 +53,14 @@ parser.add_argument(
'--batch_size', type=int, default=32, '--batch_size', type=int, default=32,
help='Batch size for training and evaluation.') help='Batch size for training and evaluation.')
parser.add_argument(
'--data_format', type=str, default=None,
choices=['channels_first', 'channels_last'],
help='A flag to override the data format used in the model. channels_first '
'provides a performance boost on GPU but is not always compatible '
'with CPU. If left unspecified, the data format will be chosen '
'automatically based on whether TensorFlow was built for CPU or GPU.')
_DEFAULT_IMAGE_SIZE = 224 _DEFAULT_IMAGE_SIZE = 224
_NUM_CHANNELS = 3 _NUM_CHANNELS = 3
_LABEL_CLASSES = 1001 _LABEL_CLASSES = 1001
...@@ -148,8 +156,8 @@ def resnet_model_fn(features, labels, mode): ...@@ -148,8 +156,8 @@ def resnet_model_fn(features, labels, mode):
"""Our model_fn for ResNet to be used with our Estimator.""" """Our model_fn for ResNet to be used with our Estimator."""
tf.summary.image('images', features, max_outputs=6) tf.summary.image('images', features, max_outputs=6)
network = resnet_model.resnet_v2( network = resnet_model.imagenet_resnet_v2(
resnet_size=FLAGS.resnet_size, num_classes=_LABEL_CLASSES) FLAGS.resnet_size, _LABEL_CLASSES, FLAGS.data_format)
logits = network( logits = network(
inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN)) inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
......
...@@ -45,7 +45,7 @@ class BaseTest(tf.test.TestCase): ...@@ -45,7 +45,7 @@ class BaseTest(tf.test.TestCase):
with graph.as_default(), self.test_session( with graph.as_default(), self.test_session(
use_gpu=with_gpu, force_gpu=with_gpu): use_gpu=with_gpu, force_gpu=with_gpu):
model = resnet_model.resnet_v2( model = resnet_model.imagenet_resnet_v2(
resnet_size, 456, resnet_size, 456,
data_format='channels_first' if with_gpu else 'channels_last') data_format='channels_first' if with_gpu else 'channels_last')
inputs = tf.random_uniform([1, 224, 224, 3]) inputs = tf.random_uniform([1, 224, 224, 3])
......
...@@ -346,7 +346,7 @@ def imagenet_resnet_v2_generator(block_fn, layers, num_classes, ...@@ -346,7 +346,7 @@ def imagenet_resnet_v2_generator(block_fn, layers, num_classes,
return model return model
def resnet_v2(resnet_size, num_classes, data_format=None): def imagenet_resnet_v2(resnet_size, num_classes, data_format=None):
"""Returns the ResNet model for a given size and number of output classes.""" """Returns the ResNet model for a given size and number of output classes."""
model_params = { model_params = {
18: {'block': building_block, 'layers': [2, 2, 2, 2]}, 18: {'block': building_block, 'layers': [2, 2, 2, 2]},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment