#  Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.mnist import dataset
from official.utils.arg_parsers import parsers
from official.utils.logs import hooks_helper
from official.utils.misc import model_helpers

LEARNING_RATE = 1e-4


def create_model(data_format):
  """Model to recognize digits in the MNIST dataset.

  Network structure is equivalent to:
  https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
  and
  https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py

  But uses the tf.keras API.

  Args:
    data_format: Either 'channels_first' or 'channels_last'.
      'channels_first' is typically faster on GPUs while 'channels_last' is
      typically faster on CPUs. See
      https://www.tensorflow.org/performance/performance_guide#data_formats

  Returns:
    A tf.keras.Model.
  """
  input_shape = None
  if data_format == 'channels_first':
    input_shape = [1, 28, 28]
  else:
    assert data_format == 'channels_last'
    input_shape = [28, 28, 1]

  L = tf.keras.layers
  max_pool = L.MaxPooling2D(
      (2, 2), (2, 2), padding='same', data_format=data_format)
  return tf.keras.Sequential([
      L.Reshape(target_shape=input_shape),
      L.Conv2D(
          32, 5, padding='same', data_format=data_format,
          activation=tf.nn.relu),
      max_pool,
      L.Conv2D(
          64, 5, padding='same', data_format=data_format,
          activation=tf.nn.relu),
      max_pool,
      L.Flatten(),
      L.Dense(1024, activation=tf.nn.relu),
      L.Dropout(0.4),
      L.Dense(10)])
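
# Illustrative sketch (an assumption, not exercised by this script): the
# returned Keras model maps a batch of MNIST images (784 values per example,
# in any layout) to 10 logits, e.g.
#   model = create_model('channels_last')
#   logits = model(tf.zeros([64, 784]), training=False)  # shape [64, 10]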


def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator."""
  model = create_model(params['data_format'])
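  # `features` is a dict keyed by 'image' when examples arrive through the
  # serving input receiver built in main(); otherwise it is the image tensor.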
  image = features
  if isinstance(image, dict):
    image = features['image']

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits),
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

    # If we are running multi-GPU, we need to wrap the optimizer.
    if params.get('multi_gpu'):
      optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

    logits = model(image, training=True)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(logits, axis=1))

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(LEARNING_RATE, 'learning_rate')
    tf.identity(loss, 'cross_entropy')
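    # `accuracy` is a (value_op, update_op) pair from tf.metrics.accuracy;
    # logging the update op keeps the running metric current during training.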
    tf.identity(accuracy[1], name='train_accuracy')

    # Save accuracy scalar to Tensorboard output.
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
  if mode == tf.estimator.ModeKeys.EVAL:
    logits = model(image, training=False)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metric_ops={
            'accuracy':
                tf.metrics.accuracy(
                    labels=labels,
                    predictions=tf.argmax(logits, axis=1)),
        })


def validate_batch_size_for_multi_gpu(batch_size):
  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

  Note that this should eventually be handled by replicate_model_fn
  directly. Multi-GPU support is currently experimental, however,
  so we do the work here until that feature is in place.

  Args:
    batch_size: the number of examples processed in each training batch.

  Raises:
    ValueError: if no GPUs are found, or selected batch_size is invalid.
  """
  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top

  local_device_protos = device_lib.list_local_devices()
  num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
  if not num_gpus:
    raise ValueError('Multi-GPU mode was specified, but no GPUs '
                     'were found. To use CPU, run without --multi_gpu.')

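  # If the batch does not divide evenly, suggest the largest smaller batch
  # size that is a multiple of the GPU count.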
  remainder = batch_size % num_gpus
  if remainder:
    err = ('When running with multiple GPUs, batch size '
           'must be a multiple of the number of available GPUs. '
           'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
          ).format(num_gpus, batch_size, batch_size - remainder)
    raise ValueError(err)


def main(argv):
  parser = MNISTArgParser()
  flags = parser.parse_args(args=argv[1:])

  model_function = model_fn

  if flags.multi_gpu:
    validate_batch_size_for_multi_gpu(flags.batch_size)

    # There are two steps required if using multi-GPU: (1) wrap the model_fn,
    # and (2) wrap the optimizer. Step (1) happens here; step (2) happens in
    # the model_fn itself when the optimizer is defined.
    model_function = tf.contrib.estimator.replicate_model_fn(
        model_fn, loss_reduction=tf.losses.Reduction.MEAN)

  data_format = flags.data_format
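  # When no data format is requested, prefer channels_first on CUDA builds
  # (typically faster on GPUs) and channels_last otherwise.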
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags.model_dir,
      params={
          'data_format': data_format,
          'multi_gpu': flags.multi_gpu
      })

  # Set up training and evaluation input functions.
  def train_input_fn():
    """Prepare data for training."""

    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
    ds = dataset.train(flags.data_dir)
    ds = ds.cache().shuffle(buffer_size=50000).batch(flags.batch_size)

    # Iterate through the dataset `epochs_between_evals` times during each
    # training session.
    ds = ds.repeat(flags.epochs_between_evals)
    return ds

  def eval_input_fn():
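    # The test set is read once per evaluate() call; no shuffling or
    # repetition is needed, so the one-shot iterator is simply exhausted.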
    return dataset.test(flags.data_dir).batch(
        flags.batch_size).make_one_shot_iterator().get_next()

  # Set up hook that outputs training logs every 100 steps.
  train_hooks = hooks_helper.get_train_hooks(
      flags.hooks, batch_size=flags.batch_size)

  # Train and evaluate model.
  for _ in range(flags.train_epochs // flags.epochs_between_evals):
    mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print('\nEvaluation results:\n\t%s\n' % eval_results)

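    # Optionally stop training early once evaluation accuracy passes the
    # `stop_threshold` flag, if one was set.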
    if model_helpers.past_stop_threshold(
        flags.stop_threshold, eval_results['accuracy']):
      break

  # Export the model
  if flags.export_dir is not None:
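    # Serve raw 28x28 float images under the feature key 'image'; the model's
    # Reshape layer converts them to the layout the conv stack expects.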
    image = tf.placeholder(tf.float32, [None, 28, 28])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    mnist_classifier.export_savedmodel(flags.export_dir, input_fn)


class MNISTArgParser(argparse.ArgumentParser):
  """Argument parser for running MNIST model."""

  def __init__(self):
    super(MNISTArgParser, self).__init__(parents=[
        parsers.BaseParser(),
        parsers.ImageModelParser(),
        parsers.ExportParser(),
    ])

    self.set_defaults(
        data_dir='/tmp/mnist_data',
        model_dir='/tmp/mnist_model',
        batch_size=100,
        train_epochs=40)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  main(argv=sys.argv)