#  Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


def train_dataset(data_dir):
  """Returns a tf.data.Dataset yielding (image, label) pairs for training."""
  data = input_data.read_data_sets(data_dir, one_hot=True).train
  return tf.data.Dataset.from_tensor_slices((data.images, data.labels))


def eval_dataset(data_dir):
  """Returns a tf.data.Dataset yielding (image, label) pairs for evaluation."""
  data = input_data.read_data_sets(data_dir, one_hot=True).test
  return tf.data.Dataset.from_tensors((data.images, data.labels))


class Model(object):
  """Class that defines a graph to recognize digits in the MNIST dataset."""

  def __init__(self, data_format):
    """Creates a model for classifying a hand-written digit.

    Args:
      data_format: Either 'channels_first' or 'channels_last'.
        'channels_first' is typically faster on GPUs while 'channels_last' is
        typically faster on CPUs. See
        https://www.tensorflow.org/performance/performance_guide#data_formats
    """
    if data_format == 'channels_first':
      self._input_shape = [-1, 1, 28, 28]
    else:
      assert data_format == 'channels_last'
      self._input_shape = [-1, 28, 28, 1]

    self.conv1 = tf.layers.Conv2D(
        32, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
    self.conv2 = tf.layers.Conv2D(
        64, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
    self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu)
    self.fc2 = tf.layers.Dense(10)
    self.dropout = tf.layers.Dropout(0.4)
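    # Dropout is active only when the layers are called with training=True
    # (see __call__ below); it is a no-op during evaluation and prediction.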
    self.max_pool2d = tf.layers.MaxPooling2D(
        (2, 2), (2, 2), padding='same', data_format=data_format)

  def __call__(self, inputs, training):
    """Add operations to classify a batch of input images.

    Args:
      inputs: A Tensor representing a batch of input images.
      training: A boolean. Set to True to add operations required only when
        training the classifier.

    Returns:
      A logits Tensor with shape [<batch_size>, 10].
    """
    y = tf.reshape(inputs, self._input_shape)
    y = self.conv1(y)
    y = self.max_pool2d(y)
    y = self.conv2(y)
    y = self.max_pool2d(y)
    y = tf.layers.flatten(y)
    y = self.fc1(y)
    y = self.dropout(y, training=training)
    return self.fc2(y)


def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator."""
  model = Model(params['data_format'])
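  # During training and evaluation, `features` is the raw image tensor produced
  # by the input_fn; when serving an exported model it arrives as a dict keyed
  # by 'image' (see the serving input receiver built in main()).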
  image = features
  if isinstance(image, dict):
    image = features['image']

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits),
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
    logits = model(image, training=True)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    accuracy = tf.metrics.accuracy(
        labels=tf.argmax(labels, axis=1), predictions=tf.argmax(logits, axis=1))
    # Name the accuracy tensor 'train_accuracy' to demonstrate the
    # LoggingTensorHook.
    tf.identity(accuracy[1], name='train_accuracy')
    tf.summary.scalar('train_accuracy', accuracy[1])
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
  if mode == tf.estimator.ModeKeys.EVAL:
    logits = model(image, training=False)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metric_ops={
            'accuracy':
                tf.metrics.accuracy(
                    labels=tf.argmax(labels, axis=1),
                    predictions=tf.argmax(logits, axis=1)),
        })


def main(unused_argv):
  data_format = FLAGS.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=FLAGS.model_dir,
      params={
          'data_format': data_format
      })

  # Train the model
  def train_input_fn():
    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
    dataset = train_dataset(FLAGS.data_dir)
    dataset = dataset.shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
        FLAGS.train_epochs)
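    # Batching before repeat() keeps epoch boundaries intact: the final,
    # possibly smaller batch of each epoch is emitted rather than mixed with
    # the next epoch's examples.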
    (images, labels) = dataset.make_one_shot_iterator().get_next()
    return (images, labels)

  # Set up training hook that logs the training accuracy every 100 steps.
  tensors_to_log = {'train_accuracy': 'train_accuracy'}
  logging_hook = tf.train.LoggingTensorHook(
      tensors=tensors_to_log, every_n_iter=100)
  mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])

  # Evaluate the model and print results
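  # eval_dataset() uses Dataset.from_tensors, so the entire test set is
  # evaluated as a single batch.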
  def eval_input_fn():
    return eval_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()

  eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
  print()
  print('Evaluation results:\n\t%s' % eval_results)

  # Export the model
  if FLAGS.export_dir is not None:
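    # The serving signature accepts batches of raw 28x28 float images under the
    # key 'image', which matches the dict lookup at the top of model_fn.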
    image = tf.placeholder(tf.float32, [None, 28, 28])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    mnist_classifier.export_savedmodel(FLAGS.export_dir, input_fn)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='Number of images to process in a batch')
  parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/mnist_data',
      help='Path to directory containing the MNIST dataset')
  parser.add_argument(
      '--model_dir',
      type=str,
      default='/tmp/mnist_model',
      help='The directory where the model will be stored.')
  parser.add_argument(
      '--train_epochs', type=int, default=40, help='Number of epochs to train.')
  parser.add_argument(
      '--data_format',
      type=str,
      default=None,
      choices=['channels_first', 'channels_last'],
      help='A flag to override the data format used in the model. channels_first '
      'provides a performance boost on GPU but is not always compatible '
      'with CPU. If left unspecified, the data format will be chosen '
      'automatically based on whether TensorFlow was built for CPU or GPU.')
  parser.add_argument(
      '--export_dir',
      type=str,
      help='The directory where the exported SavedModel will be stored.')

  tf.logging.set_verbosity(tf.logging.INFO)
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)