#  Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.mnist import dataset
from official.utils.arg_parsers import parsers
from official.utils.logs import hooks_helper
from official.utils.misc import model_helpers

LEARNING_RATE = 1e-4

def create_model(data_format):
  """Model to recognize digits in the MNIST dataset.

  Network structure is equivalent to:
  https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
  and
  https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py

  But uses the tf.keras API.

  Args:
    data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
      typically faster on GPUs while 'channels_last' is typically faster on
      CPUs. See
      https://www.tensorflow.org/performance/performance_guide#data_formats

  Returns:
    A tf.keras.Model.
  """
  if data_format == 'channels_first':
    input_shape = [1, 28, 28]
  else:
    assert data_format == 'channels_last'
    input_shape = [28, 28, 1]

  l = tf.keras.layers
  # Both pooling stages are identical, so a single layer instance is reused.
  max_pool = l.MaxPooling2D(
      (2, 2), (2, 2), padding='same', data_format=data_format)
  # The model consists of a sequential chain of layers, so tf.keras.Sequential
  # (a subclass of tf.keras.Model) makes for a compact description.
  return tf.keras.Sequential(
      [
          # Flat 784-pixel input is reshaped into an image tensor laid out
          # according to `data_format`.
          l.Reshape(input_shape),
          l.Conv2D(
              32,
              5,
              padding='same',
              data_format=data_format,
              activation=tf.nn.relu),
          max_pool,
          l.Conv2D(
              64,
              5,
              padding='same',
              data_format=data_format,
              activation=tf.nn.relu),
          max_pool,
          l.Flatten(),
          l.Dense(1024, activation=tf.nn.relu),
          l.Dropout(0.4),
          # Final layer emits raw logits for the 10 digit classes; softmax is
          # applied by the loss / prediction code in model_fn.
          l.Dense(10)
      ])


def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator.

  Args:
    features: input batch; either a tensor of images or a dict with an
      'image' key (as produced by the serving input receiver).
    labels: integer class labels (unused in PREDICT mode).
    mode: a tf.estimator.ModeKeys value selecting train/eval/predict graphs.
    params: dict with 'data_format' and optional 'multi_gpu' entries.

  Returns:
    A tf.estimator.EstimatorSpec for the requested mode.
  """
  model = create_model(params['data_format'])
  image = features
  if isinstance(image, dict):
    image = features['image']

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits),
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

    # If we are running multi-GPU, we need to wrap the optimizer.
    if params.get('multi_gpu'):
      optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

    logits = model(image, training=True)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(logits, axis=1))

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(LEARNING_RATE, 'learning_rate')
    tf.identity(loss, 'cross_entropy')
    tf.identity(accuracy[1], name='train_accuracy')

    # Save accuracy scalar to Tensorboard output.
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
  if mode == tf.estimator.ModeKeys.EVAL:
    logits = model(image, training=False)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metric_ops={
            'accuracy':
                tf.metrics.accuracy(
                    labels=labels, predictions=tf.argmax(logits, axis=1)),
        })


def validate_batch_size_for_multi_gpu(batch_size):
  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

  Note that this should eventually be handled by replicate_model_fn
  directly. Multi-GPU support is currently experimental, however,
  so doing the work here until that feature is in place.

  Args:
    batch_size: the number of examples processed in each training batch.

  Raises:
    ValueError: if no GPUs are found, or selected batch_size is invalid.
  """
  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top

  local_device_protos = device_lib.list_local_devices()
  # Count visible GPU devices; a generator avoids building a throwaway list.
  num_gpus = sum(1 for d in local_device_protos if d.device_type == 'GPU')
  if not num_gpus:
    raise ValueError('Multi-GPU mode was specified, but no GPUs '
                     'were found. To use CPU, run without --multi_gpu.')

  remainder = batch_size % num_gpus
  if remainder:
    # Suggest the nearest smaller batch size that divides evenly.
    err = ('When running with multiple GPUs, batch size '
           'must be a multiple of the number of available GPUs. '
           'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
          ).format(num_gpus, batch_size, batch_size - remainder)
    raise ValueError(err)


def main(argv):
  """Parse flags, then train, evaluate, and optionally export the model.

  Args:
    argv: command-line arguments, with argv[0] being the program name.
  """
  parser = MNISTArgParser()
  flags = parser.parse_args(args=argv[1:])

  model_function = model_fn

  if flags.multi_gpu:
    validate_batch_size_for_multi_gpu(flags.batch_size)

    # There are two steps required if using multi-GPU: (1) wrap the model_fn,
    # and (2) wrap the optimizer. The first happens here, and (2) happens
    # in the model_fn itself when the optimizer is defined.
    model_function = tf.contrib.estimator.replicate_model_fn(
        model_fn, loss_reduction=tf.losses.Reduction.MEAN)

  data_format = flags.data_format
  if data_format is None:
    # Default to the layout that is typically fastest for the build target.
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags.model_dir,
      params={
          'data_format': data_format,
          'multi_gpu': flags.multi_gpu
      })

  # Set up training and evaluation input functions.
  def train_input_fn():
    """Prepare data for training."""

    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
    ds = dataset.train(flags.data_dir)
    ds = ds.cache().shuffle(buffer_size=50000).batch(flags.batch_size)

    # Iterate through the dataset a set number (`epochs_between_evals`) of times
    # during each training session.
    ds = ds.repeat(flags.epochs_between_evals)
    return ds

  def eval_input_fn():
    """Return one batch-iterated pass over the test set."""
    return dataset.test(flags.data_dir).batch(
        flags.batch_size).make_one_shot_iterator().get_next()

  # Set up hook that outputs training logs every 100 steps.
  train_hooks = hooks_helper.get_train_hooks(
      flags.hooks, batch_size=flags.batch_size)

  # Train and evaluate model.
  for _ in range(flags.train_epochs // flags.epochs_between_evals):
    mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print('\nEvaluation results:\n\t%s\n' % eval_results)

    # Allow early stopping once the accuracy threshold (if any) is reached.
    if model_helpers.past_stop_threshold(flags.stop_threshold,
                                         eval_results['accuracy']):
      break

  # Export the model
  if flags.export_dir is not None:
    image = tf.placeholder(tf.float32, [None, 28, 28])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    mnist_classifier.export_savedmodel(flags.export_dir, input_fn)


class MNISTArgParser(argparse.ArgumentParser):
  """Argument parser for running MNIST model."""

  def __init__(self):
    # Compose the shared flag definitions used across official models.
    super(MNISTArgParser, self).__init__(parents=[
        parsers.BaseParser(),
        parsers.ImageModelParser(),
        parsers.ExportParser(),
    ])

    # MNIST-specific defaults override the generic parser defaults.
    self.set_defaults(
        data_dir='/tmp/mnist_data',
        model_dir='/tmp/mnist_model',
        batch_size=100,
        train_epochs=40)


if __name__ == '__main__':
  # Surface INFO-level logs (e.g. LoggingTensorHook output) on the console.
  tf.logging.set_verbosity(tf.logging.INFO)
  main(argv=sys.argv)