mnist.py 9.02 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#  Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

20
21
from absl import app as absl_app
from absl import flags
Karmel Allison's avatar
Karmel Allison committed
22
import tensorflow as tf  # pylint: disable=g-bad-import-order
23

24
from official.mnist import dataset
25
from official.utils.flags import core as flags_core
26
from official.utils.logs import hooks_helper
27
from official.utils.misc import model_helpers
28

29

30
LEARNING_RATE = 1e-4  # Fixed Adam learning rate used by model_fn (also logged as 'learning_rate').
31

Karmel Allison's avatar
Karmel Allison committed
32

33
def create_model(data_format):
  """Build the MNIST digit-recognition network as a tf.keras model.

  The architecture matches the classic two-conv-layer MNIST examples
  (conv -> pool -> conv -> pool -> dense -> dropout -> logits), expressed
  with the tf.keras API.

  Args:
    data_format: Either 'channels_first' or 'channels_last'.
      'channels_first' is typically faster on GPUs while 'channels_last'
      is typically faster on CPUs. See
      https://www.tensorflow.org/performance/performance_guide#data_formats

  Returns:
    A tf.keras.Model producing 10 unscaled logits per example.
  """
  if data_format == 'channels_first':
    input_shape = [1, 28, 28]
  else:
    assert data_format == 'channels_last'
    input_shape = [28, 28, 1]

  layers = tf.keras.layers
  # A single pooling-layer instance is deliberately reused for both
  # pooling steps, mirroring the original graph structure.
  pool = layers.MaxPooling2D(
      (2, 2), (2, 2), padding='same', data_format=data_format)

  # Build the sequential chain layer by layer; inputs arrive as flat
  # 784-vectors and are reshaped to image tensors first.
  model = tf.keras.Sequential()
  model.add(layers.Reshape(target_shape=input_shape, input_shape=(28 * 28,)))
  model.add(layers.Conv2D(
      32, 5, padding='same', data_format=data_format, activation=tf.nn.relu))
  model.add(pool)
  model.add(layers.Conv2D(
      64, 5, padding='same', data_format=data_format, activation=tf.nn.relu))
  model.add(pool)
  model.add(layers.Flatten())
  model.add(layers.Dense(1024, activation=tf.nn.relu))
  model.add(layers.Dropout(0.4))
  model.add(layers.Dense(10))
  return model
Asim Shankar's avatar
Asim Shankar committed
87
88


89
90
91
92
93
94
95
96
97
98
def define_mnist_flags():
  """Register the command-line flags used by the MNIST train/eval loop."""
  # Pull in the shared base and image flag definitions, then promote the
  # framework module's key flags onto this binary.
  flags_core.define_base()
  flags_core.define_image()
  flags.adopt_module_key_flags(flags_core)
  # MNIST-specific defaults for data/model locations and training schedule.
  flags_core.set_defaults(
      data_dir='/tmp/mnist_data',
      model_dir='/tmp/mnist_model',
      batch_size=100,
      train_epochs=40)


Asim Shankar's avatar
Asim Shankar committed
99
100
def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator.

  Args:
    features: input tensor, or a dict containing it under the 'image' key.
    labels: sparse integer class labels (unused in PREDICT mode).
    mode: a tf.estimator.ModeKeys value selecting the graph to build.
    params: dict with 'data_format' and optional 'multi_gpu' entries.

  Returns:
    A tf.estimator.EstimatorSpec for the requested mode.
  """
  model = create_model(params['data_format'])

  # Serving/predict inputs may arrive wrapped in a feature dict.
  if isinstance(features, dict):
    image = features['image']
  else:
    image = features

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits),
    }
    export_outputs = {
        'classify': tf.estimator.export.PredictOutput(predictions)
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        export_outputs=export_outputs)

  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

    # If we are running multi-GPU, we need to wrap the optimizer; the
    # matching model_fn wrapping happens in run_mnist.
    if params.get('multi_gpu'):
      optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

    logits = model(image, training=True)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(logits, axis=1))

    # Give these tensors stable names so LoggingTensorHook can find them.
    tf.identity(LEARNING_RATE, 'learning_rate')
    tf.identity(loss, 'cross_entropy')
    tf.identity(accuracy[1], name='train_accuracy')

    # Expose training accuracy in TensorBoard as well.
    tf.summary.scalar('train_accuracy', accuracy[1])

    train_op = optimizer.minimize(loss, tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        train_op=train_op)

  if mode == tf.estimator.ModeKeys.EVAL:
    logits = model(image, training=False)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    eval_metrics = {
        'accuracy':
            tf.metrics.accuracy(
                labels=labels, predictions=tf.argmax(logits, axis=1)),
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metric_ops=eval_metrics)
153
154


155
def validate_batch_size_for_multi_gpu(batch_size):
  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

  Note that this should eventually be handled by replicate_model_fn
  directly. Multi-GPU support is currently experimental, however,
  so doing the work here until that feature is in place.

  Args:
    batch_size: the number of examples processed in each training batch.

  Raises:
    ValueError: if no GPUs are found, or selected batch_size is invalid.
  """
  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top

  local_device_protos = device_lib.list_local_devices()
  # Count GPUs with a generator expression — no throwaway list is built.
  num_gpus = sum(1 for d in local_device_protos if d.device_type == 'GPU')
  if not num_gpus:
    raise ValueError('Multi-GPU mode was specified, but no GPUs '
                     'were found. To use CPU, run without --multi_gpu.')

  remainder = batch_size % num_gpus
  if remainder:
    # Suggest the nearest smaller batch size that divides evenly.
    err = ('When running with multiple GPUs, batch size '
           'must be a multiple of the number of available GPUs. '
           'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
          ).format(num_gpus, batch_size, batch_size - remainder)
    raise ValueError(err)


185
186
187
188
189
190
191
def run_mnist(flags_obj):
  """Run MNIST training and eval loop.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  estimator_model_fn = model_fn

  if flags_obj.multi_gpu:
    validate_batch_size_for_multi_gpu(flags_obj.batch_size)

    # Multi-GPU needs two wrappings: (1) the model_fn here, and (2) the
    # optimizer inside model_fn itself.
    estimator_model_fn = tf.contrib.estimator.replicate_model_fn(
        model_fn, loss_reduction=tf.losses.Reduction.MEAN)

  # Default the data format based on whether TF was built with CUDA:
  # channels_first tends to be faster on GPU, channels_last on CPU.
  data_format = flags_obj.data_format
  if data_format is None:
    if tf.test.is_built_with_cuda():
      data_format = 'channels_first'
    else:
      data_format = 'channels_last'

  estimator_params = {
      'data_format': data_format,
      'multi_gpu': flags_obj.multi_gpu
  }
  mnist_classifier = tf.estimator.Estimator(
      model_fn=estimator_model_fn,
      model_dir=flags_obj.model_dir,
      params=estimator_params)

  # Input pipelines for training and evaluation.
  def train_input_fn():
    """Prepare data for training."""
    # MNIST is small enough to cache and shuffle a full epoch in memory;
    # a large shuffle buffer maximizes randomness.
    ds = dataset.train(flags_obj.data_dir)
    ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)
    # Repeat for `epochs_between_evals` epochs per training session.
    return ds.repeat(flags_obj.epochs_between_evals)

  def eval_input_fn():
    """Yield the full test set, batched, via a one-shot iterator."""
    test_ds = dataset.test(flags_obj.data_dir).batch(flags_obj.batch_size)
    return test_ds.make_one_shot_iterator().get_next()

  # Hook that emits training logs every 100 steps.
  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks, batch_size=flags_obj.batch_size)

  # Alternate train and evaluate until the epoch budget is spent or the
  # accuracy threshold is reached.
  num_eval_rounds = flags_obj.train_epochs // flags_obj.epochs_between_evals
  for _ in range(num_eval_rounds):
    mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print('\nEvaluation results:\n\t%s\n' % eval_results)

    if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                         eval_results['accuracy']):
      break

  # Export the model as a SavedModel for serving, if requested.
  if flags_obj.export_dir is not None:
    image = tf.placeholder(tf.float32, [None, 28, 28])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn)
255
256


257
258
259
260
def main(_):
  """Absl entry point: run the MNIST train/eval loop with parsed flags."""
  run_mnist(flags.FLAGS)


261
if __name__ == '__main__':
  # Surface INFO-level logs (step/loss output from training hooks).
  tf.logging.set_verbosity(tf.logging.INFO)
  define_mnist_flags()
  absl_app.run(main)