#  Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import tensorflow as tf

from official.mnist import dataset
from official.utils.arg_parsers import parsers
from official.utils.logging import hooks_helper

LEARNING_RATE = 1e-4

class Model(tf.keras.Model):
  """Model to recognize digits in the MNIST dataset.

  Network structure is equivalent to:
  https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
  and
  https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py

  But written as a tf.keras.Model using the tf.layers API.
  """

  def __init__(self, data_format):
    """Creates a model for classifying a hand-written digit.

    Args:
      data_format: Either 'channels_first' or 'channels_last'.
        'channels_first' is typically faster on GPUs while 'channels_last' is
        typically faster on CPUs. See
        https://www.tensorflow.org/performance/performance_guide#data_formats
    """
    super(Model, self).__init__()
    if data_format == 'channels_first':
      self._input_shape = [-1, 1, 28, 28]
    else:
      assert data_format == 'channels_last'
      self._input_shape = [-1, 28, 28, 1]

    self.conv1 = tf.layers.Conv2D(
        32, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
    self.conv2 = tf.layers.Conv2D(
        64, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
    self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu)
    self.fc2 = tf.layers.Dense(10)
    self.dropout = tf.layers.Dropout(0.4)
    # Pooling has no trainable variables, so a single MaxPooling2D layer can
    # be applied at both pooling steps in __call__.
    self.max_pool2d = tf.layers.MaxPooling2D(
        (2, 2), (2, 2), padding='same', data_format=data_format)

  def __call__(self, inputs, training):
    """Add operations to classify a batch of input images.

    Args:
      inputs: A Tensor representing a batch of input images.
      training: A boolean. Set to True to add operations required only when
        training the classifier.

    Returns:
      A logits Tensor with shape [<batch_size>, 10].
    """
    y = tf.reshape(inputs, self._input_shape)
    y = self.conv1(y)
    y = self.max_pool2d(y)
    y = self.conv2(y)
    y = self.max_pool2d(y)
    y = tf.layers.flatten(y)
    y = self.fc1(y)
    y = self.dropout(y, training=training)
    return self.fc2(y)
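  # Minimal usage sketch (`images` here is a hypothetical float32 batch of
  # MNIST pixels, not a name defined in this file):
  #   model = Model('channels_last')
  #   logits = model(images, training=False)  # shape [batch_size, 10]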


def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator."""
  model = Model(params['data_format'])
  image = features
  if isinstance(image, dict):
    image = features['image']

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits),
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })
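  # The 'classify' key above becomes the signature name of the SavedModel
  # produced by export_savedmodel() at the end of main().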
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

    # If we are running multi-GPU, we need to wrap the optimizer.
    if params.get('multi_gpu'):
      optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

    logits = model(image, training=True)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(logits, axis=1))

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(LEARNING_RATE, name='learning_rate')
    tf.identity(loss, name='cross_entropy')
    tf.identity(accuracy[1], name='train_accuracy')

    # Save accuracy scalar to TensorBoard output.
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
  if mode == tf.estimator.ModeKeys.EVAL:
    logits = model(image, training=False)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metric_ops={
            'accuracy':
                tf.metrics.accuracy(
                    labels=labels,
                    predictions=tf.argmax(logits, axis=1)),
        })


def validate_batch_size_for_multi_gpu(batch_size):
  """Ensure that batch_size is a multiple of the number of available GPUs.

  Note that this should eventually be handled by replicate_model_fn
  directly. Multi-GPU support is currently experimental, however,
  so doing the work here until that feature is in place.
  """
  from tensorflow.python.client import device_lib

  local_device_protos = device_lib.list_local_devices()
  num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
  if not num_gpus:
    raise ValueError('Multi-GPU mode was specified, but no GPUs were found. '
                     'To use CPU, run without --multi_gpu.')

  remainder = batch_size % num_gpus
  if remainder:
    err = ('When running with multiple GPUs, batch size '
           'must be a multiple of the number of available GPUs. '
           'Found {} GPUs with a batch size of {}; try --batch_size={} '
           'instead.').format(num_gpus, batch_size, batch_size - remainder)
    raise ValueError(err)
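  # For example, with 4 GPUs a batch size of 100 passes (100 % 4 == 0); with
  # 3 GPUs it raises and suggests --batch_size=99 (100 - 100 % 3).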


def main(unused_argv):
  model_function = model_fn

  if FLAGS.multi_gpu:
    validate_batch_size_for_multi_gpu(FLAGS.batch_size)

    # There are two steps required if using multi-GPU: (1) wrap the model_fn,
    # and (2) wrap the optimizer. The first happens here, and (2) happens
    # in the model_fn itself when the optimizer is defined.
    model_function = tf.contrib.estimator.replicate_model_fn(
        model_fn, loss_reduction=tf.losses.Reduction.MEAN)

  data_format = FLAGS.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=FLAGS.model_dir,
      params={
          'data_format': data_format,
          'multi_gpu': FLAGS.multi_gpu
      })

  # Set up training and evaluation input functions.
  def train_input_fn():
    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
    ds = dataset.train(FLAGS.data_dir)
    # cache() keeps the decoded examples in memory after the first pass, so
    # later epochs skip the input-processing work.
    ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size)

    # Iterate through the dataset a set number (`epochs_between_evals`) of
    # times during each training session.
    ds = ds.repeat(FLAGS.epochs_between_evals)
    return ds

  def eval_input_fn():
    # The test set is batched but not repeated, so evaluate() makes exactly
    # one pass over it.
    return dataset.test(FLAGS.data_dir).batch(
        FLAGS.batch_size).make_one_shot_iterator().get_next()

  # Set up hook that outputs training logs every 100 steps.
  train_hooks = hooks_helper.get_train_hooks(
      FLAGS.hooks, batch_size=FLAGS.batch_size)

  # Train and evaluate model.
  for _ in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
    mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print('\nEvaluation results:\n\t%s\n' % eval_results)

  # Export the model
  if FLAGS.export_dir is not None:
    image = tf.placeholder(tf.float32, [None, 28, 28])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    mnist_classifier.export_savedmodel(FLAGS.export_dir, input_fn)
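    # The exported SavedModel takes a float32 'image' feed of shape
    # [None, 28, 28] and serves the 'classify' signature from model_fn.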


class MNISTArgParser(argparse.ArgumentParser):
  """Argument parser for running MNIST model."""
  def __init__(self):
    super(MNISTArgParser, self).__init__(parents=[
        parsers.BaseParser(),
        parsers.ImageModelParser()])

    self.add_argument(
        '--export_dir',
        type=str,
        help='[default: %(default)s] If set, a SavedModel serialization of the '
             'model will be exported to this directory at the end of training. '
             'See the README for more details and relevant links.')

    self.set_defaults(
        data_dir='/tmp/mnist_data',
        model_dir='/tmp/mnist_model',
        batch_size=100,
        train_epochs=40)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  parser = MNISTArgParser()
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
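# Example invocation (paths are placeholders):
#   python mnist.py --data_dir=/tmp/mnist_data --model_dir=/tmp/mnist_model \
#       --train_epochs=40 --export_dir=/tmp/mnist_saved_model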