mnist.py 9.26 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#  Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
21
import sys
22

Karmel Allison's avatar
Karmel Allison committed
23
import tensorflow as tf  # pylint: disable=g-bad-import-order
24

25
from official.mnist import dataset
26
from official.utils.arg_parsers import parsers
27
from official.utils.logs import hooks_helper
28
from official.utils.misc import model_helpers
29

30
# Learning rate handed to the Adam optimizer in model_fn's TRAIN branch; the
# same value is also exposed there as the 'learning_rate' tensor for logging.
LEARNING_RATE = 1e-4
31

Karmel Allison's avatar
Karmel Allison committed
32

Asim Shankar's avatar
Asim Shankar committed
33
class Model(tf.keras.Model):
  """Model to recognize digits in the MNIST dataset.

  Network structure is equivalent to:
  https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
  and
  https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py

  But written as a tf.keras.Model using the tf.layers API.

  The forward pass is: reshape -> conv1 -> max-pool -> conv2 -> max-pool ->
  flatten -> fc1 -> dropout -> fc2 (10 logits).
  """

  def __init__(self, data_format):
    """Creates a model for classifying a hand-written digit.

    Args:
      data_format: Either 'channels_first' or 'channels_last'.
        'channels_first' is typically faster on GPUs while 'channels_last' is
        typically faster on CPUs. See
        https://www.tensorflow.org/performance/performance_guide#data_formats

    Raises:
      ValueError: if data_format is neither 'channels_first' nor
        'channels_last'.
    """
    super(Model, self).__init__()
    # Shape every input batch is reshaped to before the first convolution;
    # -1 lets the batch dimension be inferred.
    if data_format == 'channels_first':
      self._input_shape = [-1, 1, 28, 28]
    elif data_format == 'channels_last':
      self._input_shape = [-1, 28, 28, 1]
    else:
      # Raise explicitly rather than `assert`: assertions are stripped when
      # Python runs with -O, which would let a bad value through silently.
      raise ValueError(
          "data_format must be 'channels_first' or 'channels_last', "
          'got {!r}.'.format(data_format))

    # NOTE: layers are created in a fixed order; keep it stable so that
    # automatically generated layer/variable names do not change.
    self.conv1 = tf.layers.Conv2D(
        32, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
    self.conv2 = tf.layers.Conv2D(
        64, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
    self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu)
    self.fc2 = tf.layers.Dense(10)
    self.dropout = tf.layers.Dropout(0.4)
    # A single pooling layer object is reused after both convolutions.
    self.max_pool2d = tf.layers.MaxPooling2D(
        (2, 2), (2, 2), padding='same', data_format=data_format)

  def __call__(self, inputs, training):
    """Add operations to classify a batch of input images.

    Args:
      inputs: A Tensor representing a batch of input images. It is reshaped
        to the layout selected by `data_format` in the constructor.
      training: A boolean. Set to True to add operations required only when
        training the classifier.

    Returns:
      A logits Tensor with shape [<batch_size>, 10].
    """
    y = tf.reshape(inputs, self._input_shape)
    y = self.conv1(y)
    y = self.max_pool2d(y)
    y = self.conv2(y)
    y = self.max_pool2d(y)
    y = tf.layers.flatten(y)
    y = self.fc1(y)
    # `training` is forwarded only to dropout; no other layer receives it.
    y = self.dropout(y, training=training)
    return self.fc2(y)


def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator.

  Args:
    features: The input images, either as a Tensor or as a dict holding the
      Tensor under the key 'image'.
    labels: Labels passed to the sparse softmax cross-entropy loss and the
      accuracy metric; unused in PREDICT mode.
    mode: One of the tf.estimator.ModeKeys values.
    params: Dict with a required 'data_format' entry and an optional
      'multi_gpu' entry.

  Returns:
    A tf.estimator.EstimatorSpec for PREDICT, TRAIN or EVAL mode. Any other
    mode falls through and yields None.
  """
  model = Model(params['data_format'])

  # Accept both a bare image tensor and a {'image': tensor} dict.
  image = features['image'] if isinstance(features, dict) else features

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
    predicted_classes = tf.argmax(logits, axis=1)
    class_probabilities = tf.nn.softmax(logits)
    predictions = {
        'classes': predicted_classes,
        'probabilities': class_probabilities,
    }
    export_outputs = {
        'classify': tf.estimator.export.PredictOutput(predictions)
    }
    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions, export_outputs=export_outputs)

  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
    if params.get('multi_gpu'):
      # Multi-GPU runs require the optimizer to be wrapped as well.
      optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

    logits = model(image, training=True)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    predicted_classes = tf.argmax(logits, axis=1)
    # tf.metrics.accuracy returns a pair; only the second element is used.
    _, accuracy_update = tf.metrics.accuracy(
        labels=labels, predictions=predicted_classes)

    # Give the tensors stable names so a LoggingTensorHook can look them up.
    tf.identity(LEARNING_RATE, 'learning_rate')
    tf.identity(loss, 'cross_entropy')
    tf.identity(accuracy_update, name='train_accuracy')

    # Also record the training accuracy as a TensorBoard scalar.
    tf.summary.scalar('train_accuracy', accuracy_update)

    global_step = tf.train.get_or_create_global_step()
    train_op = optimizer.minimize(loss, global_step)
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

  if mode == tf.estimator.ModeKeys.EVAL:
    logits = model(image, training=False)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    predicted_classes = tf.argmax(logits, axis=1)
    eval_metric_ops = {
        'accuracy':
            tf.metrics.accuracy(labels=labels, predictions=predicted_classes),
    }
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
147
148


149
def validate_batch_size_for_multi_gpu(batch_size):
  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

  Note that this should eventually be handled by replicate_model_fn
  directly. Multi-GPU support is currently experimental, however,
  so doing the work here until that feature is in place.

  Args:
    batch_size: the number of examples processed in each training batch.

  Raises:
    ValueError: if no GPUs are found, or selected batch_size is invalid.
  """
  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top

  local_device_protos = device_lib.list_local_devices()
  num_gpus = sum(1 for d in local_device_protos if d.device_type == 'GPU')
  if not num_gpus:
    raise ValueError('Multi-GPU mode was specified, but no GPUs '
                     'were found. To use CPU, run without --multi_gpu.')

  remainder = batch_size % num_gpus
  if remainder:
    # Suggest the nearest lower multiple of num_gpus, but never less than one
    # example per GPU: for batch_size < num_gpus plain rounding down would
    # recommend --batch_size=0.
    suggested_batch_size = max(batch_size - remainder, num_gpus)
    err = ('When running with multiple GPUs, batch size '
           'must be a multiple of the number of available GPUs. '
           'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
          ).format(num_gpus, batch_size, suggested_batch_size)
    raise ValueError(err)


179
180
181
182
def main(argv):
  """Trains and evaluates the MNIST classifier, then optionally exports it.

  Args:
    argv: Full command-line argument list; the first element is skipped
      before the remainder is handed to MNISTArgParser.
  """
  flags = MNISTArgParser().parse_args(args=argv[1:])

  if flags.multi_gpu:
    validate_batch_size_for_multi_gpu(flags.batch_size)

    # Multi-GPU needs two wrappers: the model_fn is wrapped here, and the
    # optimizer is wrapped inside model_fn where it is created.
    model_function = tf.contrib.estimator.replicate_model_fn(
        model_fn, loss_reduction=tf.losses.Reduction.MEAN)
  else:
    model_function = model_fn

  data_format = flags.data_format
  if data_format is None:
    # No explicit choice: pick the layout based on whether TensorFlow was
    # built with CUDA.
    if tf.test.is_built_with_cuda():
      data_format = 'channels_first'
    else:
      data_format = 'channels_last'

  estimator_params = {
      'data_format': data_format,
      'multi_gpu': flags.multi_gpu
  }
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags.model_dir,
      params=estimator_params)

  def train_input_fn():
    """Prepare data for training."""
    # Cache the examples, shuffle with a 50000-element buffer, then batch.
    examples = dataset.train(flags.data_dir).cache()
    shuffled = examples.shuffle(buffer_size=50000)
    batches = shuffled.batch(flags.batch_size)

    # Each call to train() walks the data `epochs_between_evals` times.
    return batches.repeat(flags.epochs_between_evals)

  def eval_input_fn():
    """Prepare data for evaluation."""
    batches = dataset.test(flags.data_dir).batch(flags.batch_size)
    return batches.make_one_shot_iterator().get_next()

  # Training hooks are selected by the `hooks` flag.
  train_hooks = hooks_helper.get_train_hooks(
      flags.hooks, batch_size=flags.batch_size)

  # Alternate between training and evaluation until all cycles are done or
  # the accuracy stop threshold is passed.
  num_cycles = flags.train_epochs // flags.epochs_between_evals
  for _ in range(num_cycles):
    mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print('\nEvaluation results:\n\t%s\n' % eval_results)

    accuracy = eval_results['accuracy']
    if model_helpers.past_stop_threshold(flags.stop_threshold, accuracy):
      break

  # Nothing left to do unless an export directory was requested.
  if flags.export_dir is None:
    return

  # Export a SavedModel whose serving input is a batch of 28x28 float images
  # supplied under the 'image' key.
  image = tf.placeholder(tf.float32, [None, 28, 28])
  serving_features = {
      'image': image,
  }
  input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
      serving_features)
  mnist_classifier.export_savedmodel(flags.export_dir, input_fn)
246

247

248
class MNISTArgParser(argparse.ArgumentParser):
  """Argument parser for running MNIST model.

  Combines the base, image-model and export parent parsers and overrides the
  defaults for data_dir, model_dir, batch_size and train_epochs.
  """

  def __init__(self):
    """Builds the parser from its parents and installs MNIST defaults."""
    parent_parsers = [
        parsers.BaseParser(),
        parsers.ImageModelParser(),
        parsers.ExportParser(),
    ]
    super(MNISTArgParser, self).__init__(parents=parent_parsers)

    mnist_defaults = {
        'data_dir': '/tmp/mnist_data',
        'model_dir': '/tmp/mnist_model',
        'batch_size': 100,
        'train_epochs': 40,
    }
    self.set_defaults(**mnist_defaults)
263
264
265


if __name__ == '__main__':
  # Enable INFO-level TensorFlow logging before the run starts, then hand the
  # raw command line to main().
  tf.logging.set_verbosity(tf.logging.INFO)
  main(sys.argv)