Commit afd3e677 authored by Asim Shankar

[mnist]: Tweaks

- Remove `convert_to_records.py` and instead create `tf.data.Dataset`
  objects directly from the numpy arrays.
- Format to the Google Python Style using yapf (https://github.com/google/yapf/)
parent 5a5d3305
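
In short: instead of writing the MNIST arrays out as TFRecords and reading them back, the tutorial now builds `tf.data.Dataset` objects straight from the in-memory numpy arrays. A minimal sketch of the new pattern (TF 1.x-era APIs, matching the code in the diff below; the `/tmp/mnist_data` path is just the flag default):

```
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST as in-memory numpy arrays (images are [N, 784] float32, labels
# are [N, 10] with one_hot=True) and wrap them directly in a tf.data.Dataset,
# skipping the TFRecord conversion step entirely.
data = input_data.read_data_sets('/tmp/mnist_data', one_hot=True).train
dataset = tf.data.Dataset.from_tensor_slices((data.images, data.labels))
```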
...
@@ -12,13 +12,6 @@ APIs.
 ## Setup
 To begin, you'll simply need the latest version of TensorFlow installed.
-First convert the MNIST data to TFRecord file format by running the following:
-```
-python convert_to_records.py
-```
 Then to train the model, run the following:
 ```
...
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts MNIST data to TFRecords file format with Example protos.
To read about optimizations that can be applied to the input preprocessing
stage, see: https://www.tensorflow.org/performance/performance_guide#input_pipeline_optimization.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets import mnist
parser = argparse.ArgumentParser()
parser.add_argument('--directory', type=str, default='/tmp/mnist_data',
                    help='Directory to download data files and write the '
                    'converted result.')
parser.add_argument('--validation_size', type=int, default=0,
                    help='Number of examples to separate from the training '
                    'data for the validation set.')


def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def convert_to(dataset, name, directory):
  """Converts a dataset to TFRecords."""
  images = dataset.images
  labels = dataset.labels
  num_examples = dataset.num_examples
  if images.shape[0] != num_examples:
    raise ValueError('Images size %d does not match label size %d.' %
                     (images.shape[0], num_examples))
  rows = images.shape[1]
  cols = images.shape[2]
  depth = images.shape[3]

  filename = os.path.join(directory, name + '.tfrecords')
  print('Writing', filename)
  writer = tf.python_io.TFRecordWriter(filename)
  for index in range(num_examples):
    image_raw = images[index].tostring()
    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(rows),
        'width': _int64_feature(cols),
        'depth': _int64_feature(depth),
        'label': _int64_feature(int(labels[index])),
        'image_raw': _bytes_feature(image_raw)}))
    writer.write(example.SerializeToString())
  writer.close()


def main(unused_argv):
  # Get the data.
  datasets = mnist.read_data_sets(FLAGS.directory,
                                  dtype=tf.uint8,
                                  reshape=False,
                                  validation_size=FLAGS.validation_size)

  # Convert to Examples and write the result to TFRecords.
  convert_to(datasets.train, 'train', FLAGS.directory)
  convert_to(datasets.validation, 'validation', FLAGS.directory)
  convert_to(datasets.test, 'test', FLAGS.directory)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
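
For reference, the deleted script was invoked as a standalone step before training (flag values shown are the defaults defined above):

```
python convert_to_records.py --directory=/tmp/mnist_data --validation_size=0
```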
...
@@ -22,75 +22,53 @@ import os
 import sys
 import tensorflow as tf
+from tensorflow.examples.tutorials.mnist import input_data
 parser = argparse.ArgumentParser()
 # Basic model parameters.
-parser.add_argument('--batch_size', type=int, default=100,
-                    help='Number of images to process in a batch')
+parser.add_argument(
+    '--batch_size',
+    type=int,
+    default=100,
+    help='Number of images to process in a batch')
-parser.add_argument('--data_dir', type=str, default='/tmp/mnist_data',
-                    help='Path to the MNIST data directory.')
+parser.add_argument(
+    '--data_dir',
+    type=str,
+    default='/tmp/mnist_data',
+    help='Path to directory containing the MNIST dataset')
-parser.add_argument('--model_dir', type=str, default='/tmp/mnist_model',
-                    help='The directory where the model will be stored.')
+parser.add_argument(
+    '--model_dir',
+    type=str,
+    default='/tmp/mnist_model',
+    help='The directory where the model will be stored.')
-parser.add_argument('--train_epochs', type=int, default=40,
-                    help='Number of epochs to train.')
+parser.add_argument(
+    '--train_epochs', type=int, default=40, help='Number of epochs to train.')
 parser.add_argument(
-    '--data_format', type=str, default=None,
+    '--data_format',
+    type=str,
+    default=None,
     choices=['channels_first', 'channels_last'],
     help='A flag to override the data format used in the model. channels_first '
     'provides a performance boost on GPU but is not always compatible '
     'with CPU. If left unspecified, the data format will be chosen '
     'automatically based on whether TensorFlow was built for CPU or GPU.')
-_NUM_IMAGES = {
-    'train': 50000,
-    'validation': 10000,
-}
-def input_fn(is_training, filename, batch_size=1, num_epochs=1):
-  """A simple input_fn using the tf.data input pipeline."""
-  def example_parser(serialized_example):
-    """Parses a single tf.Example into image and label tensors."""
-    features = tf.parse_single_example(
-        serialized_example,
-        features={
-            'image_raw': tf.FixedLenFeature([], tf.string),
-            'label': tf.FixedLenFeature([], tf.int64),
-        })
-    image = tf.decode_raw(features['image_raw'], tf.uint8)
-    image.set_shape([28 * 28])
-    # Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]
-    image = tf.cast(image, tf.float32) / 255 - 0.5
-    label = tf.cast(features['label'], tf.int32)
-    return image, tf.one_hot(label, 10)
-  dataset = tf.data.TFRecordDataset([filename])
-  # Apply dataset transformations
-  if is_training:
-    # When choosing shuffle buffer sizes, larger sizes result in better
-    # randomness, while smaller sizes have better performance. Because MNIST is
-    # a small dataset, we can easily shuffle the full epoch.
-    dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])
-  # We call repeat after shuffling, rather than before, to prevent separate
-  # epochs from blending together.
-  dataset = dataset.repeat(num_epochs)
-  # Map example_parser over dataset, and batch results by up to batch_size
-  dataset = dataset.map(example_parser).prefetch(batch_size)
-  dataset = dataset.batch(batch_size)
-  iterator = dataset.make_one_shot_iterator()
-  images, labels = iterator.get_next()
-  return images, labels
+def train_dataset(data_dir):
+  """Returns a tf.data.Dataset yielding (image, label) pairs for training."""
+  data = input_data.read_data_sets(data_dir, one_hot=True).train
+  return tf.data.Dataset.from_tensor_slices((data.images, data.labels))
+
+
+def eval_dataset(data_dir):
+  """Returns a tf.data.Dataset yielding (image, label) pairs for evaluation."""
+  data = input_data.read_data_sets(data_dir, one_hot=True).test
+  return tf.data.Dataset.from_tensors((data.images, data.labels))
 def mnist_model(inputs, mode, data_format):
...
@@ -104,8 +82,8 @@ def mnist_model(inputs, mode, data_format):
     # When running on GPU, transpose the data from channels_last (NHWC) to
     # channels_first (NCHW) to improve performance.
     # See https://www.tensorflow.org/performance/performance_guide#data_formats
-    data_format = ('channels_first' if tf.test.is_built_with_cuda() else
-                   'channels_last')
+    data_format = ('channels_first'
+                   if tf.test.is_built_with_cuda() else 'channels_last')
   if data_format == 'channels_first':
     inputs = tf.transpose(inputs, [0, 3, 1, 2])
...
@@ -127,8 +105,8 @@ def mnist_model(inputs, mode, data_format):
   # First max pooling layer with a 2x2 filter and stride of 2
   # Input Tensor Shape: [batch_size, 28, 28, 32]
   # Output Tensor Shape: [batch_size, 14, 14, 32]
-  pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2,
-                                  data_format=data_format)
+  pool1 = tf.layers.max_pooling2d(
+      inputs=conv1, pool_size=[2, 2], strides=2, data_format=data_format)
   # Convolutional Layer #2
   # Computes 64 features using a 5x5 filter.
...
@@ -147,8 +125,8 @@ def mnist_model(inputs, mode, data_format):
   # Second max pooling layer with a 2x2 filter and stride of 2
   # Input Tensor Shape: [batch_size, 14, 14, 64]
   # Output Tensor Shape: [batch_size, 7, 7, 64]
-  pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2,
-                                  data_format=data_format)
+  pool2 = tf.layers.max_pooling2d(
+      inputs=conv2, pool_size=[2, 2], strides=2, data_format=data_format)
   # Flatten tensor into a batch of vectors
   # Input Tensor Shape: [batch_size, 7, 7, 64]
...
@@ -159,8 +137,7 @@ def mnist_model(inputs, mode, data_format):
   # Densely connected layer with 1024 neurons
   # Input Tensor Shape: [batch_size, 7 * 7 * 64]
   # Output Tensor Shape: [batch_size, 1024]
-  dense = tf.layers.dense(inputs=pool2_flat, units=1024,
-                          activation=tf.nn.relu)
+  dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
   # Add dropout operation; 0.6 probability that element will be kept
   dropout = tf.layers.dropout(
...
@@ -211,34 +188,37 @@ def mnist_model_fn(features, labels, mode, params):
 def main(unused_argv):
-  # Make sure that training and testing data have been converted.
-  train_file = os.path.join(FLAGS.data_dir, 'train.tfrecords')
-  test_file = os.path.join(FLAGS.data_dir, 'test.tfrecords')
-  assert (tf.gfile.Exists(train_file) and tf.gfile.Exists(test_file)), (
-      'Run convert_to_records.py first to convert the MNIST data to TFRecord '
-      'file format.')
   # Create the Estimator
   mnist_classifier = tf.estimator.Estimator(
-      model_fn=mnist_model_fn, model_dir=FLAGS.model_dir,
-      params={'data_format': FLAGS.data_format})
+      model_fn=mnist_model_fn,
+      model_dir=FLAGS.model_dir,
+      params={
+          'data_format': FLAGS.data_format
+      })
   # Set up training hook that logs the training accuracy every 100 steps.
-  tensors_to_log = {
-      'train_accuracy': 'train_accuracy'
-  }
+  tensors_to_log = {'train_accuracy': 'train_accuracy'}
   logging_hook = tf.train.LoggingTensorHook(
       tensors=tensors_to_log, every_n_iter=100)
   # Train the model
-  mnist_classifier.train(
-      input_fn=lambda: input_fn(
-          True, train_file, FLAGS.batch_size, FLAGS.train_epochs),
-      hooks=[logging_hook])
+  def train_input_fn():
+    # When choosing shuffle buffer sizes, larger sizes result in better
+    # randomness, while smaller sizes use less memory. MNIST is a small
+    # enough dataset that we can easily shuffle the full epoch.
+    dataset = train_dataset(FLAGS.data_dir)
+    dataset = dataset.shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
+        FLAGS.train_epochs)
+    (images, labels) = dataset.make_one_shot_iterator().get_next()
+    return (images, labels)
+
+  mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])
   # Evaluate the model and print results
-  eval_results = mnist_classifier.evaluate(
-      input_fn=lambda: input_fn(False, test_file, FLAGS.batch_size))
+  def eval_input_fn():
+    return eval_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()
+
+  eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
   print()
   print('Evaluation results:\n\t%s' % eval_results)
...
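
One detail worth noting in the additions above: `train_dataset` uses `Dataset.from_tensor_slices` while `eval_dataset` uses `Dataset.from_tensors`. A small sketch of the difference (array shapes are illustrative; MNIST test images are 10000 x 784 after flattening):

```
import numpy as np
import tensorflow as tf

images = np.zeros((10000, 784), dtype=np.float32)
labels = np.zeros((10000, 10), dtype=np.float32)

# from_tensor_slices treats the first dimension as the element axis, so the
# training pipeline sees one (image, label) pair per element and can shuffle
# and batch individual examples.
train = tf.data.Dataset.from_tensor_slices((images, labels))

# from_tensors wraps the entire pair of arrays as a single element, so
# evaluation processes the whole test set as one batch.
test = tf.data.Dataset.from_tensors((images, labels))
```

This is why the training input_fn applies `shuffle` and `batch` while the eval input_fn simply pulls the single element off the iterator.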