#  Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.mnist import dataset
from official.utils.arg_parsers import parsers
from official.utils.logs import hooks_helper
from official.utils.misc import model_helpers

LEARNING_RATE = 1e-4


def create_model(data_format):
  """Model to recognize digits in the MNIST dataset.

  Network structure is equivalent to:
  https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
  and
  https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py

  But uses the tf.keras API.

  Args:
    data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
      typically faster on GPUs while 'channels_last' is typically faster on
      CPUs. See
      https://www.tensorflow.org/performance/performance_guide#data_formats

  Returns:
    A tf.keras.Model.
  """
  if data_format == 'channels_first':
    input_shape = [1, 28, 28]
  else:
    assert data_format == 'channels_last'
    input_shape = [28, 28, 1]

  l = tf.keras.layers
  max_pool = l.MaxPooling2D(
      (2, 2), (2, 2), padding='same', data_format=data_format)
  return tf.keras.Sequential(
      [
          l.Reshape(input_shape),
          l.Conv2D(
              32,
              5,
              padding='same',
              data_format=data_format,
              activation=tf.nn.relu),
          max_pool,
          l.Conv2D(
              64,
              5,
              padding='same',
              data_format=data_format,
              activation=tf.nn.relu),
          max_pool,
          l.Flatten(),
          l.Dense(1024, activation=tf.nn.relu),
          l.Dropout(0.4),
          l.Dense(10)
      ])
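# A minimal usage sketch (editorial addition, not part of the original script):
# the model takes flat 784-pixel MNIST images and returns unnormalized logits
# for the 10 digit classes.
#
#   model = create_model('channels_last')
#   logits = model(tf.zeros([8, 784]), training=False)  # logits shape: [8, 10]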


def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator."""
  model = create_model(params['data_format'])
  image = features
  if isinstance(image, dict):
    image = features['image']

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits),
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

    # If we are running multi-GPU, we need to wrap the optimizer.
    if params.get('multi_gpu'):
      optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
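      # (Editorial note) TowerOptimizer aggregates the per-tower gradients
      # computed by the replicate_model_fn wrapper applied in main().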

    logits = model(image, training=True)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(logits, axis=1))

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(LEARNING_RATE, 'learning_rate')
    tf.identity(loss, 'cross_entropy')
    tf.identity(accuracy[1], name='train_accuracy')

    # Save accuracy scalar to Tensorboard output.
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
  if mode == tf.estimator.ModeKeys.EVAL:
    logits = model(image, training=False)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metric_ops={
            'accuracy':
                tf.metrics.accuracy(
                    labels=labels, predictions=tf.argmax(logits, axis=1)),
        })


def validate_batch_size_for_multi_gpu(batch_size):
  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

  Note that this should eventually be handled by replicate_model_fn
  directly. Multi-GPU support is currently experimental, however,
  so we do the work here until that feature is in place.

  Args:
    batch_size: the number of examples processed in each training batch.

  Raises:
    ValueError: if no GPUs are found, or selected batch_size is invalid.
  """
  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top

  local_device_protos = device_lib.list_local_devices()
  num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
  if not num_gpus:
    raise ValueError('Multi-GPU mode was specified, but no GPUs '
                     'were found. To use CPU, run without --multi_gpu.')

  remainder = batch_size % num_gpus
  if remainder:
    err = ('When running with multiple GPUs, batch size '
           'must be a multiple of the number of available GPUs. '
           'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
          ).format(num_gpus, batch_size, batch_size - remainder)
    raise ValueError(err)


def main(argv):
  parser = MNISTArgParser()
  flags = parser.parse_args(args=argv[1:])

  model_function = model_fn

  if flags.multi_gpu:
    validate_batch_size_for_multi_gpu(flags.batch_size)

    # There are two steps required if using multi-GPU: (1) wrap the model_fn,
    # and (2) wrap the optimizer. The first happens here, and (2) happens
    # in the model_fn itself when the optimizer is defined.
    model_function = tf.contrib.estimator.replicate_model_fn(
        model_fn, loss_reduction=tf.losses.Reduction.MEAN)

  data_format = flags.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags.model_dir,
      params={
          'data_format': data_format,
          'multi_gpu': flags.multi_gpu
      })

  # Set up training and evaluation input functions.
  def train_input_fn():
    """Prepare data for training."""

    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
    ds = dataset.train(flags.data_dir)
    ds = ds.cache().shuffle(buffer_size=50000).batch(flags.batch_size)
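    # (Editorial note) Each element of `ds` is an (image, label) pair; the
    # images are assumed to arrive flattened to 784 floats, which is what the
    # Reshape layer in create_model expects.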

    # Iterate through the dataset a set number (`epochs_between_evals`) of times
    # during each training session.
    ds = ds.repeat(flags.epochs_between_evals)
    return ds

  def eval_input_fn():
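    # (Editorial note) Unlike train_input_fn, this returns the next-batch
    # tensors directly; an Estimator input_fn may return either a
    # tf.data.Dataset or a (features, labels) tuple of tensors.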
    return dataset.test(flags.data_dir).batch(
        flags.batch_size).make_one_shot_iterator().get_next()

  # Set up hook that outputs training logs every 100 steps.
  train_hooks = hooks_helper.get_train_hooks(
      flags.hooks, batch_size=flags.batch_size)

  # Train and evaluate model.
  for _ in range(flags.train_epochs // flags.epochs_between_evals):
    mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print('\nEvaluation results:\n\t%s\n' % eval_results)

    if model_helpers.past_stop_threshold(flags.stop_threshold,
                                         eval_results['accuracy']):
      break

  # Export the model
  if flags.export_dir is not None:
    image = tf.placeholder(tf.float32, [None, 28, 28])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    mnist_classifier.export_savedmodel(flags.export_dir, input_fn)


class MNISTArgParser(argparse.ArgumentParser):
  """Argument parser for running MNIST model."""

  def __init__(self):
    super(MNISTArgParser, self).__init__(parents=[
        parsers.BaseParser(),
        parsers.ImageModelParser(),
        parsers.ExportParser(),
    ])

    self.set_defaults(
        data_dir='/tmp/mnist_data',
        model_dir='/tmp/mnist_model',
        batch_size=100,
        train_epochs=40)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  main(argv=sys.argv)
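
# Example invocation (editorial addition; the flag names are assumed to come
# from official.utils.arg_parsers.parsers):
#   python mnist.py --data_dir=/tmp/mnist_data --model_dir=/tmp/mnist_model \
#       --batch_size=100 --train_epochs=40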