#  Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app as absl_app
from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.mnist import dataset
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.misc import model_helpers


LEARNING_RATE = 1e-4


def create_model(data_format):
  """Model to recognize digits in the MNIST dataset.

  Network structure is equivalent to:
  https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
  and
  https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py

  But uses the tf.keras API.

  Args:
    data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
      typically faster on GPUs while 'channels_last' is typically faster on
      CPUs. See
      https://www.tensorflow.org/performance/performance_guide#data_formats

  Returns:
    A tf.keras.Model.
  """
  if data_format == 'channels_first':
    input_shape = [1, 28, 28]
  else:
    assert data_format == 'channels_last'
    input_shape = [28, 28, 1]

  l = tf.keras.layers
  max_pool = l.MaxPooling2D(
      (2, 2), (2, 2), padding='same', data_format=data_format)
  # The model consists of a sequential chain of layers, so tf.keras.Sequential
  # (a subclass of tf.keras.Model) makes for a compact description.
  return tf.keras.Sequential(
      [
          l.Reshape(
              target_shape=input_shape,
              input_shape=(28 * 28,)),
          l.Conv2D(
              32,
              5,
              padding='same',
              data_format=data_format,
              activation=tf.nn.relu),
          max_pool,
          l.Conv2D(
              64,
              5,
              padding='same',
              data_format=data_format,
              activation=tf.nn.relu),
          max_pool,
          l.Flatten(),
          l.Dense(1024, activation=tf.nn.relu),
          l.Dropout(0.4),
          l.Dense(10)
      ])
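
# A minimal usage sketch of create_model() (illustrative, not part of the
# original script): build the model and run one random batch to check that it
# yields a (batch_size, 10) tensor of logits. Assumes the TF 1.x graph-mode
# API used throughout this file.
#
#   model = create_model('channels_last')
#   logits = model(tf.random_normal([8, 28 * 28]), training=False)
#   # logits has static shape (8, 10)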


def define_mnist_flags():
  flags_core.define_base()
  flags_core.define_image()
  flags.adopt_module_key_flags(flags_core)
  flags_core.set_defaults(data_dir='/tmp/mnist_data',
                          model_dir='/tmp/mnist_model',
                          batch_size=100,
                          train_epochs=40)
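
# Illustrative invocation (not part of the original file). The flag names below
# are assumed to come from flags_core.define_base()/define_image() together
# with the defaults set above, and can be overridden on the command line:
#
#   python mnist.py --data_dir=/tmp/mnist_data --model_dir=/tmp/mnist_model \
#       --batch_size=100 --train_epochs=40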


def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator."""
  model = create_model(params['data_format'])
  image = features
  if isinstance(image, dict):
    image = features['image']

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits),
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

    # If we are running multi-GPU, we need to wrap the optimizer.
    if params.get('multi_gpu'):
      optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

    logits = model(image, training=True)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(logits, axis=1))

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(LEARNING_RATE, 'learning_rate')
    tf.identity(loss, 'cross_entropy')
    tf.identity(accuracy[1], name='train_accuracy')

    # Save accuracy scalar to TensorBoard output.
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
  if mode == tf.estimator.ModeKeys.EVAL:
    logits = model(image, training=False)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metric_ops={
            'accuracy':
                tf.metrics.accuracy(
                    labels=labels, predictions=tf.argmax(logits, axis=1)),
        })


def validate_batch_size_for_multi_gpu(batch_size):
  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

  Note that this should eventually be handled by replicate_model_fn
  directly. Multi-GPU support is currently experimental, however,
  so we do the work here until that feature is in place.

  Args:
    batch_size: the number of examples processed in each training batch.

  Raises:
    ValueError: if no GPUs are found, or selected batch_size is invalid.
  """
  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top

  local_device_protos = device_lib.list_local_devices()
  num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
  if not num_gpus:
    raise ValueError('Multi-GPU mode was specified, but no GPUs '
                     'were found. To use CPU, run without --multi_gpu.')

  remainder = batch_size % num_gpus
  if remainder:
    err = ('When running with multiple GPUs, batch size '
           'must be a multiple of the number of available GPUs. '
           'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
          ).format(num_gpus, batch_size, batch_size - remainder)
    raise ValueError(err)
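
# Worked example for the check above (illustrative, not in the original file):
# with 3 GPUs and --batch_size=100 the remainder is 100 % 3 == 1, so the error
# message suggests --batch_size=99 (batch_size - remainder).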


def main(flags_obj):
  model_function = model_fn

  if flags_obj.multi_gpu:
    validate_batch_size_for_multi_gpu(flags_obj.batch_size)

    # There are two steps required if using multi-GPU: (1) wrap the model_fn,
    # and (2) wrap the optimizer. The first happens here, and (2) happens
    # in the model_fn itself when the optimizer is defined.
    model_function = tf.contrib.estimator.replicate_model_fn(
        model_fn, loss_reduction=tf.losses.Reduction.MEAN)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      params={
          'data_format': data_format,
          'multi_gpu': flags_obj.multi_gpu
      })

  # Set up training and evaluation input functions.
  def train_input_fn():
    """Prepare data for training."""

    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
    ds = dataset.train(flags_obj.data_dir)
    ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)

    # Iterate through the dataset a set number (`epochs_between_evals`) of times
    # during each training session.
    ds = ds.repeat(flags_obj.epochs_between_evals)
    return ds

  def eval_input_fn():
    return dataset.test(flags_obj.data_dir).batch(
        flags_obj.batch_size).make_one_shot_iterator().get_next()

  # Set up hook that outputs training logs every 100 steps.
  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks, batch_size=flags_obj.batch_size)

  # Train and evaluate model.
  for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
    mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print('\nEvaluation results:\n\t%s\n' % eval_results)

    if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                         eval_results['accuracy']):
      break

  # Export the model
  if flags_obj.export_dir is not None:
    image = tf.placeholder(tf.float32, [None, 28, 28])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn)
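    # The exported SavedModel is written to a timestamped subdirectory of
    # flags_obj.export_dir and can be inspected with TensorFlow's
    # saved_model_cli tool (illustrative command, not part of this script):
    #   saved_model_cli show --dir <export_dir>/<timestamp> --all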


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  define_mnist_flags()
  absl_app.run(main)