mnist.py 8.11 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#  Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
22
import sys
23
24

import tensorflow as tf
Asim Shankar's avatar
Asim Shankar committed
25
from tensorflow.examples.tutorials.mnist import input_data
26
27
28
29

# Command-line interface. The parsed namespace is stored in the module-level
# FLAGS by the __main__ guard and read from there by main().
parser = argparse.ArgumentParser()

# Basic model parameters.
parser.add_argument(
    '--batch_size',
    type=int,
    default=100,
    help='Number of images to process in a batch')

parser.add_argument(
    '--data_dir',
    type=str,
    default='/tmp/mnist_data',
    help='Path to directory containing the MNIST dataset')

parser.add_argument(
    '--model_dir',
    type=str,
    default='/tmp/mnist_model',
    help='The directory where the model will be stored.')

parser.add_argument(
    '--train_epochs', type=int, default=40, help='Number of epochs to train.')

# Left as None by default so that main() can pick a format based on how
# TensorFlow was built.
parser.add_argument(
    '--data_format',
    type=str,
    default=None,
    choices=['channels_first', 'channels_last'],
    help='A flag to override the data format used in the model. channels_first '
    'provides a performance boost on GPU but is not always compatible '
    'with CPU. If left unspecified, the data format will be chosen '
    'automatically based on whether TensorFlow was built for CPU or GPU.')

# No default is given, so the value is None unless the flag is passed; main()
# only exports a SavedModel when it is set.
parser.add_argument(
    '--export_dir',
    type=str,
    help='The directory where the exported SavedModel will be stored.')
65

Asim Shankar's avatar
Asim Shankar committed
66

Asim Shankar's avatar
Asim Shankar committed
67
68
69
70
def train_dataset(data_dir):
  """Builds the training tf.data.Dataset of (image, label) pairs.

  Args:
    data_dir: Directory handed to `input_data.read_data_sets` as the location
      of the MNIST data.

  Returns:
    A tf.data.Dataset sliced along the first dimension of the training
    images and their one-hot labels, i.e. one element per training example.
  """
  mnist = input_data.read_data_sets(data_dir, one_hot=True)
  train_split = mnist.train
  images_and_labels = (train_split.images, train_split.labels)
  return tf.data.Dataset.from_tensor_slices(images_and_labels)
71
72


Asim Shankar's avatar
Asim Shankar committed
73
74
75
76
def eval_dataset(data_dir):
  """Builds the evaluation tf.data.Dataset from the MNIST test split.

  Args:
    data_dir: Directory handed to `input_data.read_data_sets` as the location
      of the MNIST data.

  Returns:
    A tf.data.Dataset built with `from_tensors`, so it holds a single
    (images, labels) element containing the whole test split with one-hot
    labels, rather than one element per example.
  """
  mnist = input_data.read_data_sets(data_dir, one_hot=True)
  test_split = mnist.test
  images_and_labels = (test_split.images, test_split.labels)
  return tf.data.Dataset.from_tensors(images_and_labels)
77
78


Asim Shankar's avatar
Asim Shankar committed
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
class Model(object):
  """Class that defines a graph to recognize digits in the MNIST dataset.

  The network is: reshape -> conv(32, 5x5) -> max-pool -> conv(64, 5x5) ->
  max-pool -> flatten -> dense(1024) -> dropout(0.5) -> dense(10).
  """

  def __init__(self, data_format):
    """Creates a model for classifying a hand-written digit.

    Args:
      data_format: Either 'channels_first' or 'channels_last'.
        'channels_first' is typically faster on GPUs while 'channels_last' is
        typically faster on CPUs. See
        https://www.tensorflow.org/performance/performance_guide#data_formats

    Raises:
      ValueError: If `data_format` is not one of the two supported values.
    """
    # Validate explicitly rather than with `assert`, which is stripped when
    # Python runs with -O and would let a bad value fall through silently.
    if data_format == 'channels_first':
      self._input_shape = [-1, 1, 28, 28]
    elif data_format == 'channels_last':
      self._input_shape = [-1, 28, 28, 1]
    else:
      raise ValueError(
          "data_format must be 'channels_first' or 'channels_last', got %r" %
          (data_format,))

    self.conv1 = tf.layers.Conv2D(
        32, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
    self.conv2 = tf.layers.Conv2D(
        64, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
    self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu)
    self.fc2 = tf.layers.Dense(10)
    self.dropout = tf.layers.Dropout(0.5)
    # A single pooling layer object is applied after each convolution.
    self.max_pool2d = tf.layers.MaxPooling2D(
        (2, 2), (2, 2), padding='same', data_format=data_format)

  def __call__(self, inputs, training):
    """Add operations to classify a batch of input images.

    Args:
      inputs: A Tensor representing a batch of input images.
      training: A boolean. Set to True to add operations required only when
        training the classifier.

    Returns:
      A logits Tensor with shape [<batch_size>, 10].
    """
    # Reshape to the 4-D layout chosen in __init__; -1 leaves the batch
    # dimension to be inferred.
    y = tf.reshape(inputs, self._input_shape)
    y = self.conv1(y)
    y = self.max_pool2d(y)
    y = self.conv2(y)
    y = self.max_pool2d(y)
    y = tf.layers.flatten(y)
    y = self.fc1(y)
    y = self.dropout(y, training=training)
    return self.fc2(y)


def predict_spec(model, image):
  """EstimatorSpec for predictions.

  Args:
    model: Callable invoked as `model(image, training=False)` to produce
      logits.
    image: Either the image Tensor itself or a dict whose 'image' entry is
      that Tensor.

  Returns:
    A tf.estimator.EstimatorSpec in PREDICT mode whose predictions contain
    'classes' (argmax of the logits along axis 1) and 'probabilities'
    (softmax of the logits), also exposed under the 'classify' export output.
  """
  # Features may arrive wrapped in a dict keyed by 'image'; unwrap them.
  if isinstance(image, dict):
    image = image['image']
  logits = model(image, training=False)
  predictions = {
      'classes': tf.argmax(logits, axis=1),
      'probabilities': tf.nn.softmax(logits),
  }
  return tf.estimator.EstimatorSpec(
      mode=tf.estimator.ModeKeys.PREDICT,
      predictions=predictions,
      export_outputs={
          'classify': tf.estimator.export.PredictOutput(predictions)
      })
144
145


Asim Shankar's avatar
Asim Shankar committed
146
147
148
149
def train_spec(model, image, labels):
  """EstimatorSpec for training.

  Args:
    model: Callable invoked as `model(image, training=True)` to produce
      logits.
    image: Tensor with a batch of input images.
    labels: One-hot labels for the batch (passed as `onehot_labels` to the
      loss and reduced with argmax along axis 1 for the accuracy metric).

  Returns:
    A tf.estimator.EstimatorSpec in TRAIN mode with the softmax cross-entropy
    loss and an Adam (learning rate 1e-4) train op that advances the global
    step.
  """
  optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
  logits = model(image, training=True)
  loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
  accuracy = tf.metrics.accuracy(
      labels=tf.argmax(labels, axis=1), predictions=tf.argmax(logits, axis=1))
  # Name the accuracy tensor 'train_accuracy' to demonstrate the
  # LoggingTensorHook. tf.metrics.accuracy returns an (accuracy, update_op)
  # pair; element 1 is the one named and summarized here.
  tf.identity(accuracy[1], name='train_accuracy')
  tf.summary.scalar('train_accuracy', accuracy[1])
  return tf.estimator.EstimatorSpec(
      mode=tf.estimator.ModeKeys.TRAIN,
      loss=loss,
      train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
161

Asim Shankar's avatar
Asim Shankar committed
162
163
164
165
166

def eval_spec(model, image, labels):
  """EstimatorSpec for evaluation.

  Args:
    model: Callable invoked as `model(image, training=False)` to produce
      logits.
    image: Tensor with a batch of input images.
    labels: One-hot labels for the batch (passed as `onehot_labels` to the
      loss and reduced with argmax along axis 1 for the accuracy metric).

  Returns:
    A tf.estimator.EstimatorSpec in EVAL mode with the softmax cross-entropy
    loss and an 'accuracy' entry in `eval_metric_ops`.
  """
  logits = model(image, training=False)
  loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
  return tf.estimator.EstimatorSpec(
      mode=tf.estimator.ModeKeys.EVAL,
      loss=loss,
      eval_metric_ops={
          'accuracy':
              tf.metrics.accuracy(
                  labels=tf.argmax(labels, axis=1),
                  predictions=tf.argmax(logits, axis=1)),
      })


def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator.

  Args:
    features: Input features, forwarded to the mode-specific spec builder.
    labels: Labels, forwarded for TRAIN and EVAL; not used for PREDICT.
    mode: One of the tf.estimator.ModeKeys values.
    params: Dict of hyperparameters; 'data_format' is read to build the Model.

  Returns:
    The tf.estimator.EstimatorSpec for the requested mode.

  Raises:
    ValueError: If `mode` is not PREDICT, TRAIN or EVAL.
  """
  model = Model(params['data_format'])
  if mode == tf.estimator.ModeKeys.PREDICT:
    return predict_spec(model, features)
  if mode == tf.estimator.ModeKeys.TRAIN:
    return train_spec(model, features, labels)
  if mode == tf.estimator.ModeKeys.EVAL:
    return eval_spec(model, features, labels)
  # Fail loudly instead of implicitly returning None for an unknown mode.
  raise ValueError('Unsupported Estimator mode: %r' % (mode,))
187
188
189


def main(unused_argv):
  """Trains, evaluates and optionally exports the MNIST classifier.

  Reads its configuration from the module-level FLAGS namespace (batch_size,
  data_dir, model_dir, train_epochs, data_format, export_dir).

  Args:
    unused_argv: Remaining command-line arguments passed by tf.app.run;
      ignored.
  """
  data_format = FLAGS.data_format
  if data_format is None:
    # No explicit override: pick the format based on the TensorFlow build.
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=FLAGS.model_dir,
      params={
          'data_format': data_format
      })

  # Train the model
  def train_input_fn():
    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
    # NOTE(review): buffer_size is fixed at 50000 - confirm it covers the
    # size of the training split returned by train_dataset().
    dataset = train_dataset(FLAGS.data_dir)
    dataset = dataset.shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
        FLAGS.train_epochs)
    (images, labels) = dataset.make_one_shot_iterator().get_next()
    return (images, labels)

  # Set up training hook that logs the training accuracy every 100 steps.
  tensors_to_log = {'train_accuracy': 'train_accuracy'}
  logging_hook = tf.train.LoggingTensorHook(
      tensors=tensors_to_log, every_n_iter=100)
  mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])

  # Evaluate the model and print results
  def eval_input_fn():
    return eval_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()

  eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
  print()
  print('Evaluation results:\n\t%s' % eval_results)

  # Export the model
  if FLAGS.export_dir is not None:
    # Build the serving placeholder once and pass that same tensor to the
    # receiver fn, instead of creating a second, unused placeholder.
    image = tf.placeholder(tf.float32, [None, 28, 28])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    mnist_classifier.export_savedmodel(FLAGS.export_dir, input_fn)
233

234
235
236

if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  # FLAGS becomes a module-level global that main() reads. Arguments the
  # parser does not recognize are forwarded to tf.app.run after the program
  # name.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)