Unverified commit 4702de29 authored by Neal Wu, committed by GitHub

Use FLAGS in main functions only + Updates to shuffling (#2601)

parent edcd29f2
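The pattern this commit applies to every file below: argparse flags are read only inside main(), and everything downstream receives plain arguments or an Estimator params dict instead of reaching into global FLAGS. A minimal sketch of that shape, using hypothetical do_work/--output_dir names that are not part of this commit:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', default='/tmp/out')
    FLAGS = None

    def do_work(output_dir):
      # Explicit argument: callable from tests without parsing any flags.
      print('writing to', output_dir)

    def main():
      # FLAGS is read only here, then passed down as plain values.
      do_work(FLAGS.output_dir)

    if __name__ == '__main__':
      FLAGS = parser.parse_args()
      main()

This is exactly why the test files at the bottom of the diff can drop their `module.FLAGS = module.parser.parse_args()` lines.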
@@ -50,11 +50,11 @@ def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


-def convert_to(data_set, name):
+def convert_to(dataset, name, directory):
  """Converts a dataset to TFRecords."""
-  images = data_set.images
-  labels = data_set.labels
-  num_examples = data_set.num_examples
+  images = dataset.images
+  labels = dataset.labels
+  num_examples = dataset.num_examples

  if images.shape[0] != num_examples:
    raise ValueError('Images size %d does not match label size %d.' %
@@ -63,7 +63,7 @@ def convert_to(data_set, name):
  cols = images.shape[2]
  depth = images.shape[3]

-  filename = os.path.join(FLAGS.directory, name + '.tfrecords')
+  filename = os.path.join(directory, name + '.tfrecords')
  print('Writing', filename)
  writer = tf.python_io.TFRecordWriter(filename)
  for index in range(num_examples):
@@ -80,15 +80,15 @@ def convert_to(data_set, name):
def main(unused_argv):
  # Get the data.
-  data_sets = mnist.read_data_sets(FLAGS.directory,
+  datasets = mnist.read_data_sets(FLAGS.directory,
                                   dtype=tf.uint8,
                                   reshape=False,
                                   validation_size=FLAGS.validation_size)

  # Convert to Examples and write the result to TFRecords.
-  convert_to(data_sets.train, 'train')
-  convert_to(data_sets.validation, 'validation')
-  convert_to(data_sets.test, 'test')
+  convert_to(datasets.train, 'train', FLAGS.directory)
+  convert_to(datasets.validation, 'validation', FLAGS.directory)
+  convert_to(datasets.test, 'test', FLAGS.directory)


if __name__ == '__main__':
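For context on what convert_to() produces: each record is a serialized tf.train.Example, which TF 1.x can read back with tf.python_io.tf_record_iterator. A minimal sketch, assuming a train.tfrecords written by this script and the usual MNIST feature key 'label', which is not visible in this hunk:

    import tensorflow as tf

    # Inspect the first Example in the file convert_to() wrote.
    path = '/tmp/data/train.tfrecords'  # hypothetical --directory value
    for record in tf.python_io.tf_record_iterator(path):
      example = tf.train.Example()
      example.ParseFromString(record)
      # Assumed MNIST key; the writer side is outside this hunk.
      print(example.features.feature['label'].int64_list.value)
      break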
......
@@ -52,7 +52,7 @@ _NUM_IMAGES = {
}


-def input_fn(mode, batch_size=1):
+def input_fn(is_training, filename, batch_size=1, num_epochs=1):
  """A simple input_fn using the contrib.data input pipeline."""

  def example_parser(serialized_example):
@@ -71,21 +71,15 @@ def input_fn(mode, batch_size=1):
    label = tf.cast(features['label'], tf.int32)
    return image, tf.one_hot(label, 10)

-  if mode == tf.estimator.ModeKeys.TRAIN:
-    tfrecords_file = os.path.join(FLAGS.data_dir, 'train.tfrecords')
-  else:
-    assert mode == tf.estimator.ModeKeys.EVAL, 'invalid mode'
-    tfrecords_file = os.path.join(FLAGS.data_dir, 'test.tfrecords')
-
-  assert tf.gfile.Exists(tfrecords_file), (
-      'Run convert_to_records.py first to convert the MNIST data to TFRecord '
-      'file format.')
-
-  dataset = tf.contrib.data.TFRecordDataset([tfrecords_file])
-
-  # For training, repeat the dataset forever
-  if mode == tf.estimator.ModeKeys.TRAIN:
-    dataset = dataset.repeat()
+  dataset = tf.contrib.data.TFRecordDataset([filename])
+
+  if is_training:
+    # When choosing shuffle buffer sizes, larger sizes result in better
+    # randomness, while smaller sizes have better performance. Because MNIST is
+    # a small dataset, we can easily shuffle the full epoch.
+    dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])
+
+  dataset = dataset.repeat(num_epochs)

  # Map example_parser over dataset, and batch results by up to batch_size
  dataset = dataset.map(
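The comment added above is the heart of the shuffling change: Dataset.shuffle keeps a buffer of buffer_size elements and emits a uniform sample from it, so randomness improves with buffer size while memory and fill cost grow. MNIST's training set is small enough to buffer in full, giving a complete shuffle. A toy sketch of the semantics using the same TF 1.x contrib.data API (illustrative values only):

    import tensorflow as tf

    dataset = tf.contrib.data.Dataset.from_tensor_slices(list(range(10)))
    # buffer_size >= dataset size => a full uniform shuffle, as for MNIST above.
    # A small buffer (e.g. 2) would only mix nearby elements.
    dataset = dataset.shuffle(buffer_size=10)
    dataset = dataset.repeat(2)  # two passes, mirroring repeat(num_epochs)

    element = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
      for _ in range(20):
        print(sess.run(element))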
@@ -96,13 +90,12 @@ def input_fn(mode, batch_size=1):
  return images, labels


-def mnist_model(inputs, mode):
+def mnist_model(inputs, mode, data_format):
  """Takes the MNIST inputs and mode and outputs a tensor of logits."""
  # Input Layer
  # Reshape X to 4-D tensor: [batch_size, width, height, channels]
  # MNIST images are 28x28 pixels, and have one color channel
  inputs = tf.reshape(inputs, [-1, 28, 28, 1])

-  data_format = FLAGS.data_format
  if data_format is None:
    # When running on GPU, transpose the data from channels_last (NHWC) to
@@ -177,9 +170,9 @@ def mnist_model(inputs, mode):
  return logits


-def mnist_model_fn(features, labels, mode):
+def mnist_model_fn(features, labels, mode, params):
  """Model function for MNIST."""
-  logits = mnist_model(features, mode)
+  logits = mnist_model(features, mode, params['data_format'])

  predictions = {
      'classes': tf.argmax(input=logits, axis=1),
@@ -215,30 +208,36 @@ def mnist_model_fn(features, labels, mode):
def main(unused_argv):
+  # Make sure that training and testing data have been converted.
+  train_file = os.path.join(FLAGS.data_dir, 'train.tfrecords')
+  test_file = os.path.join(FLAGS.data_dir, 'test.tfrecords')
+  assert (tf.gfile.Exists(train_file) and tf.gfile.Exists(test_file)), (
+      'Run convert_to_records.py first to convert the MNIST data to TFRecord '
+      'file format.')
+
  # Create the Estimator
  mnist_classifier = tf.estimator.Estimator(
-      model_fn=mnist_model_fn, model_dir=FLAGS.model_dir)
+      model_fn=mnist_model_fn, model_dir=FLAGS.model_dir,
+      params={'data_format': FLAGS.data_format})

-  # Train the model
+  # Set up training hook that logs the training accuracy every 100 steps.
  tensors_to_log = {
      'train_accuracy': 'train_accuracy'
  }

  logging_hook = tf.train.LoggingTensorHook(
      tensors=tensors_to_log, every_n_iter=100)

-  batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
+  # Train the model
  mnist_classifier.train(
-      input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, FLAGS.batch_size),
-      steps=FLAGS.train_epochs * batches_per_epoch,
+      input_fn=lambda: input_fn(
+          True, train_file, FLAGS.batch_size, FLAGS.train_epochs),
      hooks=[logging_hook])

  # Evaluate the model and print results
  eval_results = mnist_classifier.evaluate(
-      input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL))
+      input_fn=lambda: input_fn(False, test_file, FLAGS.batch_size))
  print()
-  print('Evaluation results:\n %s' % eval_results)
+  print('Evaluation results:\n\t%s' % eval_results)


if __name__ == '__main__':
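The params dict passed to tf.estimator.Estimator above is forwarded verbatim to any model_fn that declares a params argument; that is what lets the tests below call the model functions directly with a literal dict and no FLAGS at all. A minimal self-contained sketch of the mechanism (toy one-layer model and hypothetical 'num_classes' key, not from this commit):

    import tensorflow as tf

    def toy_model_fn(features, labels, mode, params):
      # `params` here is exactly the dict given to the Estimator constructor.
      logits = tf.layers.dense(features, params['num_classes'])
      loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
      train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
          loss, global_step=tf.train.get_or_create_global_step())
      return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    estimator = tf.estimator.Estimator(
        model_fn=toy_model_fn, params={'num_classes': 10})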
......
@@ -34,7 +34,8 @@ class BaseTest(tf.test.TestCase):
  def mnist_model_fn_helper(self, mode):
    features, labels = self.input_fn()
    image_count = features.shape[0]
-    spec = mnist.mnist_model_fn(features, labels, mode)
+    spec = mnist.mnist_model_fn(
+        features, labels, mode, {'data_format': 'channels_last'})

    predictions = spec.predictions
    self.assertAllEqual(predictions['probabilities'].shape, (image_count, 10))
@@ -65,5 +66,4 @@ class BaseTest(tf.test.TestCase):

if __name__ == '__main__':
-  mnist.FLAGS = mnist.parser.parse_args()
  tf.test.main()
@@ -71,6 +71,8 @@ _NUM_IMAGES = {
    'validation': 10000,
}

+_SHUFFLE_BUFFER = 20000
+

def record_dataset(filenames):
  """Returns an input pipeline Dataset from `filenames`."""
@@ -78,9 +80,9 @@ def record_dataset(filenames):
  return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)


-def get_filenames(is_training):
+def get_filenames(is_training, data_dir):
  """Returns a list of filenames."""
-  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
+  data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')

  assert os.path.exists(data_dir), (
      'Run cifar10_download_and_extract.py first to download and extract the '
@@ -135,7 +137,7 @@ def train_preprocess_fn(image, label):
  return image, label


-def input_fn(is_training, num_epochs=1):
+def input_fn(is_training, data_dir, batch_size, num_epochs=1):
  """Input_fn using the contrib.data input pipeline for CIFAR-10 dataset.

  Args:
@@ -145,42 +147,41 @@ def input_fn(is_training, num_epochs=1):
  Returns:
    A tuple of images and labels.
  """
-  dataset = record_dataset(get_filenames(is_training))
+  dataset = record_dataset(get_filenames(is_training, data_dir))
  dataset = dataset.map(dataset_parser, num_threads=1,
-                        output_buffer_size=2 * FLAGS.batch_size)
+                        output_buffer_size=2 * batch_size)

  # For training, preprocess the image and shuffle.
  if is_training:
    dataset = dataset.map(train_preprocess_fn, num_threads=1,
-                          output_buffer_size=2 * FLAGS.batch_size)
+                          output_buffer_size=2 * batch_size)

-    # Ensure that the capacity is sufficiently large to provide good random
-    # shuffling.
-    buffer_size = int(0.4 * _NUM_IMAGES['train'])
-    dataset = dataset.shuffle(buffer_size=buffer_size)
+    # When choosing shuffle buffer sizes, larger sizes result in better
+    # randomness, while smaller sizes have better performance.
+    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)

  # Subtract off the mean and divide by the variance of the pixels.
  dataset = dataset.map(
      lambda image, label: (tf.image.per_image_standardization(image), label),
      num_threads=1,
-      output_buffer_size=2 * FLAGS.batch_size)
+      output_buffer_size=2 * batch_size)

  dataset = dataset.repeat(num_epochs)

  # Batch results by up to batch_size, and then fetch the tuple from the
  # iterator.
-  iterator = dataset.batch(FLAGS.batch_size).make_one_shot_iterator()
+  iterator = dataset.batch(batch_size).make_one_shot_iterator()
  images, labels = iterator.get_next()

  return images, labels


-def cifar10_model_fn(features, labels, mode):
+def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10."""
  tf.summary.image('images', features, max_outputs=6)

  network = resnet_model.cifar10_resnet_v2_generator(
-      FLAGS.resnet_size, _NUM_CLASSES, FLAGS.data_format)
+      params['resnet_size'], _NUM_CLASSES, params['data_format'])

  inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
  logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)
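Worth noting: for CIFAR-10's 50,000 training images, the removed expression int(0.4 * _NUM_IMAGES['train']) already evaluated to 20,000, so the new _SHUFFLE_BUFFER constant keeps the effective buffer size identical while naming the value and documenting the randomness-versus-performance tradeoff:

    # Old buffer computation vs. new named constant (CIFAR-10: 50000 train images).
    assert int(0.4 * 50000) == 20000  # == _SHUFFLE_BUFFER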
@@ -208,8 +209,8 @@ def cifar10_model_fn(features, labels, mode):
  if mode == tf.estimator.ModeKeys.TRAIN:
    # Scale the learning rate linearly with the batch size. When the batch size
    # is 128, the learning rate should be 0.1.
-    initial_learning_rate = 0.1 * FLAGS.batch_size / 128
-    batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
+    initial_learning_rate = 0.1 * params['batch_size'] / 128
+    batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
    global_step = tf.train.get_or_create_global_step()

    # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
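The linear-scaling comment above fixes the reference point at batch size 128 with learning rate 0.1, so the learning-rate contribution per example stays constant as the batch size changes. A quick check of the formula:

    for batch_size in (64, 128, 256):
      print(batch_size, 0.1 * batch_size / 128)
    # -> 64 0.05, 128 0.1, 256 0.2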
@@ -256,7 +257,12 @@ def main(unused_argv):
  # Set up a RunConfig to only save checkpoints once per training cycle.
  run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)
  cifar_classifier = tf.estimator.Estimator(
-      model_fn=cifar10_model_fn, model_dir=FLAGS.model_dir, config=run_config)
+      model_fn=cifar10_model_fn, model_dir=FLAGS.model_dir, config=run_config,
+      params={
+          'resnet_size': FLAGS.resnet_size,
+          'data_format': FLAGS.data_format,
+          'batch_size': FLAGS.batch_size,
+      })

  for _ in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
    tensors_to_log = {
@@ -270,12 +276,12 @@ def main(unused_argv):
    cifar_classifier.train(
        input_fn=lambda: input_fn(
-            is_training=True, num_epochs=FLAGS.epochs_per_eval),
+            True, FLAGS.data_dir, FLAGS.batch_size, FLAGS.epochs_per_eval),
        hooks=[logging_hook])

    # Evaluate the model and print results
    eval_results = cifar_classifier.evaluate(
-        input_fn=lambda: input_fn(is_training=False))
+        input_fn=lambda: input_fn(False, FLAGS.data_dir, FLAGS.batch_size))
    print(eval_results)
......
@@ -26,6 +26,8 @@ import cifar10_main

tf.logging.set_verbosity(tf.logging.ERROR)

+_BATCH_SIZE = 128
+

class BaseTest(tf.test.TestCase):
@@ -58,20 +60,25 @@ class BaseTest(tf.test.TestCase):
    self.assertAllEqual(pixel, np.array([0, 1, 2]))

  def input_fn(self):
-    features = tf.random_uniform([FLAGS.batch_size, 32, 32, 3])
+    features = tf.random_uniform([_BATCH_SIZE, 32, 32, 3])
    labels = tf.random_uniform(
-        [FLAGS.batch_size], maxval=9, dtype=tf.int32)
+        [_BATCH_SIZE], maxval=9, dtype=tf.int32)
    return features, tf.one_hot(labels, 10)

  def cifar10_model_fn_helper(self, mode):
    features, labels = self.input_fn()
-    spec = cifar10_main.cifar10_model_fn(features, labels, mode)
+    spec = cifar10_main.cifar10_model_fn(
+        features, labels, mode, {
+            'resnet_size': 32,
+            'data_format': 'channels_last',
+            'batch_size': _BATCH_SIZE,
+        })

    predictions = spec.predictions
    self.assertAllEqual(predictions['probabilities'].shape,
-                        (FLAGS.batch_size, 10))
+                        (_BATCH_SIZE, 10))
    self.assertEqual(predictions['probabilities'].dtype, tf.float32)
-    self.assertAllEqual(predictions['classes'].shape, (FLAGS.batch_size,))
+    self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
    self.assertEqual(predictions['classes'].dtype, tf.int64)

    if mode != tf.estimator.ModeKeys.PREDICT:
@@ -97,6 +104,4 @@ class BaseTest(tf.test.TestCase):

if __name__ == '__main__':
-  cifar10_main.FLAGS = cifar10_main.parser.parse_args()
-  FLAGS = cifar10_main.FLAGS
  tf.test.main()
@@ -73,16 +73,18 @@ _NUM_IMAGES = {
    'validation': 50000,
}

+_SHUFFLE_BUFFER = 1500
+

-def filenames(is_training):
+def filenames(is_training, data_dir):
  """Return filenames for dataset."""
  if is_training:
    return [
-        os.path.join(FLAGS.data_dir, 'train-%05d-of-01024' % i)
+        os.path.join(data_dir, 'train-%05d-of-01024' % i)
        for i in range(0, 1024)]
  else:
    return [
-        os.path.join(FLAGS.data_dir, 'validation-%05d-of-00128' % i)
+        os.path.join(data_dir, 'validation-%05d-of-00128' % i)
        for i in range(0, 128)]
@@ -129,9 +131,11 @@ def dataset_parser(value, is_training):
  return image, tf.one_hot(label, _LABEL_CLASSES)


-def input_fn(is_training, num_epochs=1):
+def input_fn(is_training, data_dir, batch_size, num_epochs=1):
  """Input function which provides batches for train or eval."""
-  dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames(is_training))
+  dataset = tf.contrib.data.Dataset.from_tensor_slices(
+      filenames(is_training, data_dir))

  if is_training:
    dataset = dataset.shuffle(buffer_size=1024)

  dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)
@@ -141,23 +145,24 @@ def input_fn(is_training, num_epochs=1):
  dataset = dataset.map(lambda value: dataset_parser(value, is_training),
                        num_threads=5,
-                        output_buffer_size=FLAGS.batch_size)
+                        output_buffer_size=batch_size)

  if is_training:
-    buffer_size = 1250 + 2 * FLAGS.batch_size
-    dataset = dataset.shuffle(buffer_size=buffer_size)
+    # When choosing shuffle buffer sizes, larger sizes result in better
+    # randomness, while smaller sizes have better performance.
+    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)

-  iterator = dataset.batch(FLAGS.batch_size).make_one_shot_iterator()
+  iterator = dataset.batch(batch_size).make_one_shot_iterator()
  images, labels = iterator.get_next()

  return images, labels


-def resnet_model_fn(features, labels, mode):
+def resnet_model_fn(features, labels, mode, params):
  """Our model_fn for ResNet to be used with our Estimator."""
  tf.summary.image('images', features, max_outputs=6)

  network = resnet_model.imagenet_resnet_v2(
-      FLAGS.resnet_size, _LABEL_CLASSES, FLAGS.data_format)
+      params['resnet_size'], _LABEL_CLASSES, params['data_format'])
  logits = network(
      inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
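Unlike MNIST and CIFAR-10, the ImageNet training set (over a million images) is far too large to buffer whole, so the pipeline above shuffles at two levels: first the 1024 shard filenames are fully shuffled (buffer_size=1024 covers every shard), then decoded records are approximately shuffled with the much smaller _SHUFFLE_BUFFER. A condensed sketch of that pattern, mirroring the hunks above:

    import tensorflow as tf

    filenames = ['train-%05d-of-01024' % i for i in range(1024)]
    dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames)
    dataset = dataset.shuffle(buffer_size=1024)   # exact shuffle of shard order
    dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)
    dataset = dataset.shuffle(buffer_size=1500)   # approximate record-level shuffle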
@@ -185,8 +190,8 @@ def resnet_model_fn(features, labels, mode):
  if mode == tf.estimator.ModeKeys.TRAIN:
    # Scale the learning rate linearly with the batch size. When the batch size
    # is 256, the learning rate should be 0.1.
-    initial_learning_rate = 0.1 * FLAGS.batch_size / 256
-    batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
+    initial_learning_rate = 0.1 * params['batch_size'] / 256
+    batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
    global_step = tf.train.get_or_create_global_step()

    # Multiply the learning rate by 0.1 at 30, 60, 80, and 90 epochs.
@@ -235,7 +240,12 @@ def main(unused_argv):
  # Set up a RunConfig to only save checkpoints once per training cycle.
  run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)
  resnet_classifier = tf.estimator.Estimator(
-      model_fn=resnet_model_fn, model_dir=FLAGS.model_dir, config=run_config)
+      model_fn=resnet_model_fn, model_dir=FLAGS.model_dir, config=run_config,
+      params={
+          'resnet_size': FLAGS.resnet_size,
+          'data_format': FLAGS.data_format,
+          'batch_size': FLAGS.batch_size,
+      })

  for _ in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
    tensors_to_log = {
@@ -250,12 +260,12 @@ def main(unused_argv):
    print('Starting a training cycle.')
    resnet_classifier.train(
        input_fn=lambda: input_fn(
-            is_training=True, num_epochs=FLAGS.epochs_per_eval),
+            True, FLAGS.data_dir, FLAGS.batch_size, FLAGS.epochs_per_eval),
        hooks=[logging_hook])

    print('Starting to evaluate.')
    eval_results = resnet_classifier.evaluate(
-        input_fn=lambda: input_fn(is_training=False))
+        input_fn=lambda: input_fn(False, FLAGS.data_dir, FLAGS.batch_size))
    print(eval_results)
......
@@ -26,6 +26,7 @@ import resnet_model

tf.logging.set_verbosity(tf.logging.ERROR)

+_BATCH_SIZE = 32

_LABEL_CLASSES = 1001
@@ -125,10 +126,10 @@ class BaseTest(tf.test.TestCase):
  def input_fn(self):
    """Provides random features and labels."""
-    features = tf.random_uniform([FLAGS.batch_size, 224, 224, 3])
+    features = tf.random_uniform([_BATCH_SIZE, 224, 224, 3])
    labels = tf.one_hot(
        tf.random_uniform(
-            [FLAGS.batch_size], maxval=_LABEL_CLASSES - 1,
+            [_BATCH_SIZE], maxval=_LABEL_CLASSES - 1,
            dtype=tf.int32),
        _LABEL_CLASSES)
@@ -139,13 +140,18 @@ class BaseTest(tf.test.TestCase):
    tf.train.create_global_step()
    features, labels = self.input_fn()

-    spec = imagenet_main.resnet_model_fn(features, labels, mode)
+    spec = imagenet_main.resnet_model_fn(
+        features, labels, mode, {
+            'resnet_size': 50,
+            'data_format': 'channels_last',
+            'batch_size': _BATCH_SIZE,
+        })

    predictions = spec.predictions
    self.assertAllEqual(predictions['probabilities'].shape,
-                        (FLAGS.batch_size, _LABEL_CLASSES))
+                        (_BATCH_SIZE, _LABEL_CLASSES))
    self.assertEqual(predictions['probabilities'].dtype, tf.float32)
-    self.assertAllEqual(predictions['classes'].shape, (FLAGS.batch_size,))
+    self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
    self.assertEqual(predictions['classes'].dtype, tf.int64)

    if mode != tf.estimator.ModeKeys.PREDICT:
@@ -171,6 +177,4 @@ class BaseTest(tf.test.TestCase):

if __name__ == '__main__':
-  imagenet_main.FLAGS = imagenet_main.parser.parse_args()
-  FLAGS = imagenet_main.FLAGS
  tf.test.main()