#  Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

parser = argparse.ArgumentParser()

# Basic model parameters.
parser.add_argument(
    '--batch_size',
    type=int,
    default=100,
    help='Number of images to process in a batch')

parser.add_argument(
    '--data_dir',
    type=str,
    default='/tmp/mnist_data',
    help='Path to directory containing the MNIST dataset')

parser.add_argument(
    '--model_dir',
    type=str,
    default='/tmp/mnist_model',
    help='The directory where the model will be stored.')

parser.add_argument(
    '--train_epochs', type=int, default=40, help='Number of epochs to train.')

parser.add_argument(
    '--data_format',
    type=str,
    default=None,
    choices=['channels_first', 'channels_last'],
    help='A flag to override the data format used in the model. channels_first '
    'provides a performance boost on GPU but is not always compatible '
    'with CPU. If left unspecified, the data format will be chosen '
    'automatically based on whether TensorFlow was built for CPU or GPU.')

parser.add_argument(
    '--export_dir',
    type=str,
    help='The directory where the exported SavedModel will be stored.')
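
# Example invocation (illustrative only; the paths here are arbitrary):
#   python mnist.py --data_dir=/tmp/mnist_data --model_dir=/tmp/mnist_model \
#       --export_dir=/tmp/mnist_saved_model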


def train_dataset(data_dir):
  """Returns a tf.data.Dataset yielding (image, label) pairs for training."""
  data = input_data.read_data_sets(data_dir, one_hot=True).train
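  # read_data_sets yields flattened 28x28 images as float32 values in [0, 1]
  # and, with one_hot=True, labels as one-hot vectors of length 10.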
  return tf.data.Dataset.from_tensor_slices((data.images, data.labels))


def eval_dataset(data_dir):
  """Returns a tf.data.Dataset yielding (image, label) pairs for evaluation."""
  data = input_data.read_data_sets(data_dir, one_hot=True).test
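  # from_tensors wraps the whole test set in a single element, so evaluation
  # runs over one large batch.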
  return tf.data.Dataset.from_tensors((data.images, data.labels))


class Model(object):
  """Class that defines a graph to recognize digits in the MNIST dataset."""

  def __init__(self, data_format):
    """Creates a model for classifying a hand-written digit.

    Args:
      data_format: Either 'channels_first' or 'channels_last'.
        'channels_first' is typically faster on GPUs while 'channels_last' is
        typically faster on CPUs. See
        https://www.tensorflow.org/performance/performance_guide#data_formats
    """
    if data_format == 'channels_first':
      self._input_shape = [-1, 1, 28, 28]
    else:
      assert data_format == 'channels_last'
      self._input_shape = [-1, 28, 28, 1]

    self.conv1 = tf.layers.Conv2D(
        32, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
    self.conv2 = tf.layers.Conv2D(
        64, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
    self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu)
    self.fc2 = tf.layers.Dense(10)
    self.dropout = tf.layers.Dropout(0.4)
    self.max_pool2d = tf.layers.MaxPooling2D(
        (2, 2), (2, 2), padding='same', data_format=data_format)
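    # MaxPooling2D holds no variables, so reusing the same layer object for
    # both pooling steps in __call__ is safe.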

  def __call__(self, inputs, training):
    """Add operations to classify a batch of input images.

    Args:
      inputs: A Tensor representing a batch of input images.
      training: A boolean. Set to True to add operations required only when
        training the classifier.

    Returns:
      A logits Tensor with shape [<batch_size>, 10].
    """
    y = tf.reshape(inputs, self._input_shape)
    y = self.conv1(y)
    y = self.max_pool2d(y)
    y = self.conv2(y)
    y = self.max_pool2d(y)
    y = tf.layers.flatten(y)
    y = self.fc1(y)
    y = self.dropout(y, training=training)
    return self.fc2(y)


def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator."""
  model = Model(params['data_format'])
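  # 'features' is a dict when input comes from the serving input receiver
  # built in main(); for training and evaluation it is a plain image Tensor.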
  image = features
  if isinstance(image, dict):
    image = features['image']

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits),
    }
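    # The 'classify' entry below becomes the serving signature when the model
    # is exported via export_savedmodel() in main().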
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
    logits = model(image, training=True)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    accuracy = tf.metrics.accuracy(
        labels=tf.argmax(labels, axis=1), predictions=tf.argmax(logits, axis=1))
    # Name the accuracy tensor 'train_accuracy' to demonstrate the
    # LoggingTensorHook.
    tf.identity(accuracy[1], name='train_accuracy')
    tf.summary.scalar('train_accuracy', accuracy[1])
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
  if mode == tf.estimator.ModeKeys.EVAL:
    logits = model(image, training=False)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metric_ops={
            'accuracy':
                tf.metrics.accuracy(
                    labels=tf.argmax(labels, axis=1),
                    predictions=tf.argmax(logits, axis=1)),
        })


def main(unused_argv):
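  # Pick a layout when --data_format is unset: channels_first (NCHW) when
  # TensorFlow was built with CUDA, since cuDNN kernels favor it, otherwise
  # channels_last (NHWC) for CPU.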
  data_format = FLAGS.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=FLAGS.model_dir,
      params={
          'data_format': data_format
      })

  # Train the model
  def train_input_fn():
    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
    dataset = train_dataset(FLAGS.data_dir)
    dataset = dataset.shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
        FLAGS.train_epochs)
    (images, labels) = dataset.make_one_shot_iterator().get_next()
    return (images, labels)

  # Set up training hook that logs the training accuracy every 100 steps.
  tensors_to_log = {'train_accuracy': 'train_accuracy'}
  logging_hook = tf.train.LoggingTensorHook(
      tensors=tensors_to_log, every_n_iter=100)
  mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])

  # Evaluate the model and print results
  def eval_input_fn():
    return eval_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()

  eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
  print()
  print('Evaluation results:\n\t%s' % eval_results)

  # Export the model
  if FLAGS.export_dir is not None:
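    # The exported signature accepts batches of raw 28x28 float images under
    # the 'image' key; model_fn reshapes them to the layout the network expects.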
    image = tf.placeholder(tf.float32, [None, 28, 28])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    mnist_classifier.export_savedmodel(FLAGS.export_dir, input_fn)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)