mnist.py 8.45 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#  Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
22
import sys
23
24

import tensorflow as tf
Asim Shankar's avatar
Asim Shankar committed
25
from tensorflow.examples.tutorials.mnist import input_data
26
27
28
29

parser = argparse.ArgumentParser()

# Basic model parameters. Each entry is (flag, kwargs) forwarded verbatim to
# parser.add_argument, so every default and help string lives in one table.
_ARGUMENTS = [
    ('--batch_size', dict(
        type=int,
        default=100,
        help='Number of images to process in a batch')),
    ('--data_dir', dict(
        type=str,
        default='/tmp/mnist_data',
        help='Path to directory containing the MNIST dataset')),
    ('--model_dir', dict(
        type=str,
        default='/tmp/mnist_model',
        help='The directory where the model will be stored.')),
    ('--train_epochs', dict(
        type=int, default=40, help='Number of epochs to train.')),
    ('--data_format', dict(
        type=str,
        default=None,
        choices=['channels_first', 'channels_last'],
        help='A flag to override the data format used in the model. '
             'channels_first provides a performance boost on GPU but is not '
             'always compatible with CPU. If left unspecified, the data '
             'format will be chosen automatically based on whether TensorFlow '
             'was built for CPU or GPU.')),
    ('--export_dir', dict(
        type=str,
        help='The directory where the exported SavedModel will be stored.')),
]

for _flag, _options in _ARGUMENTS:
  parser.add_argument(_flag, **_options)
65

Asim Shankar's avatar
Asim Shankar committed
66
67
68
69
def train_dataset(data_dir):
  """Returns a tf.data.Dataset yielding (image, label) pairs for training."""
  # read_data_sets downloads/caches MNIST under data_dir; one_hot=True makes
  # labels 10-way one-hot vectors, matching the softmax cross-entropy loss.
  mnist = input_data.read_data_sets(data_dir, one_hot=True)
  train_split = mnist.train
  return tf.data.Dataset.from_tensor_slices(
      (train_split.images, train_split.labels))
70
71


Asim Shankar's avatar
Asim Shankar committed
72
73
74
75
def eval_dataset(data_dir):
  """Returns a tf.data.Dataset yielding (image, label) pairs for evaluation."""
  test_split = input_data.read_data_sets(data_dir, one_hot=True).test
  # from_tensors (not from_tensor_slices): the whole test set is emitted as a
  # single batch, so evaluation runs in one step.
  images, labels = test_split.images, test_split.labels
  return tf.data.Dataset.from_tensors((images, labels))
76
77


78
def mnist_model(inputs, mode, data_format):
  """Takes the MNIST inputs and mode and outputs a tensor of logits."""
  # Reshape flattened 784-vectors into 4-D NHWC image tensors:
  # [batch_size, height, width, channels]. MNIST is 28x28, single channel.
  inputs = tf.reshape(inputs, [-1, 28, 28, 1])

  if data_format is None:
    # channels_first (NCHW) improves GPU performance but is not always
    # available on CPU, so choose based on how TensorFlow was built.
    # See https://www.tensorflow.org/performance/performance_guide#data_formats
    if tf.test.is_built_with_cuda():
      data_format = 'channels_first'
    else:
      data_format = 'channels_last'

  if data_format == 'channels_first':
    # NHWC -> NCHW.
    inputs = tf.transpose(inputs, [0, 3, 1, 2])

  # Conv block #1: 32 features from 5x5 filters with ReLU; 'same' padding
  # keeps the spatial size at 28x28, then 2x2 stride-2 pooling halves it.
  # [batch, 28, 28, 1] -> [batch, 28, 28, 32] -> [batch, 14, 14, 32]
  net = tf.layers.conv2d(
      inputs=inputs,
      filters=32,
      kernel_size=[5, 5],
      padding='same',
      activation=tf.nn.relu,
      data_format=data_format)
  net = tf.layers.max_pooling2d(
      inputs=net, pool_size=[2, 2], strides=2, data_format=data_format)

  # Conv block #2: 64 features from 5x5 filters, then pooling again.
  # [batch, 14, 14, 32] -> [batch, 14, 14, 64] -> [batch, 7, 7, 64]
  net = tf.layers.conv2d(
      inputs=net,
      filters=64,
      kernel_size=[5, 5],
      padding='same',
      activation=tf.nn.relu,
      data_format=data_format)
  net = tf.layers.max_pooling2d(
      inputs=net, pool_size=[2, 2], strides=2, data_format=data_format)

  # Flatten [batch, 7, 7, 64] into [batch, 7 * 7 * 64] and apply a densely
  # connected layer with 1024 units.
  net = tf.reshape(net, [-1, 7 * 7 * 64])
  net = tf.layers.dense(inputs=net, units=1024, activation=tf.nn.relu)

  # Dropout: 40% of activations are zeroed, and only while training
  # (0.6 probability that an element is kept).
  net = tf.layers.dropout(
      inputs=net, rate=0.4, training=(mode == tf.estimator.ModeKeys.TRAIN))

  # Final linear projection to the 10 digit classes: [batch, 1024] -> [batch, 10]
  logits = tf.layers.dense(inputs=net, units=10)
  return logits


157
def mnist_model_fn(features, labels, mode, params):
  """Model function for MNIST."""
  # Serving requests arrive as a dict keyed by the receiver tensor name;
  # unwrap to the raw image tensor before building the graph.
  if mode == tf.estimator.ModeKeys.PREDICT and isinstance(features, dict):
    features = features['image_raw']

  logits = mnist_model(features, mode, params['data_format'])

  predictions = {
      'classes': tf.argmax(input=logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    # For SavedModel export, expose the predictions under the 'classify'
    # signature.
    export_outputs = {
        'classify': tf.estimator.export.PredictOutput(predictions)
    }
    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions, export_outputs=export_outputs)

  loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)

  # Configure the training op (TRAIN mode only; EVAL gets train_op=None).
  train_op = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(
        loss, tf.train.get_or_create_global_step())

  accuracy = tf.metrics.accuracy(
      tf.argmax(labels, axis=1), predictions['classes'])

  # Publish the running accuracy under a stable tensor name so the logging
  # hook and TensorBoard summaries can find it.
  tf.identity(accuracy[1], name='train_accuracy')
  tf.summary.scalar('train_accuracy', accuracy[1])

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops={'accuracy': accuracy})


def main(unused_argv):
  """Trains, evaluates, and optionally exports the MNIST classifier."""
  # Create the Estimator.
  classifier = tf.estimator.Estimator(
      model_fn=mnist_model_fn,
      model_dir=FLAGS.model_dir,
      params={'data_format': FLAGS.data_format})

  # Training hook that logs the running training accuracy every 100 steps.
  logging_hook = tf.train.LoggingTensorHook(
      tensors={'train_accuracy': 'train_accuracy'}, every_n_iter=100)

  def train_input_fn():
    # Larger shuffle buffers give better randomness at the cost of memory;
    # MNIST is small enough that the buffer can hold a full epoch.
    pipeline = (train_dataset(FLAGS.data_dir)
                .shuffle(buffer_size=50000)
                .batch(FLAGS.batch_size)
                .repeat(FLAGS.train_epochs))
    return pipeline.make_one_shot_iterator().get_next()

  # Train the model.
  classifier.train(input_fn=train_input_fn, hooks=[logging_hook])

  def eval_input_fn():
    return eval_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()

  # Evaluate the model and print results.
  eval_results = classifier.evaluate(input_fn=eval_input_fn)
  print()
  print('Evaluation results:\n\t%s' % eval_results)

  # Export a SavedModel accepting raw [batch, 28, 28] float images.
  if FLAGS.export_dir is not None:
    image = tf.placeholder(tf.float32, [None, 28, 28])
    serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
        {"image_raw": image})
    classifier.export_savedmodel(FLAGS.export_dir, serving_input_fn)

241
242
243

if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  # Flags argparse does not recognize are forwarded unmodified to tf.app.run,
  # which re-invokes main with the leftover argv.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)