#  Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app as absl_app
from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.mnist import dataset
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers

# Learning rate handed to the Adam optimizer built in model_fn (TRAIN mode).
LEARNING_RATE = 0.0001
32

Karmel Allison's avatar
Karmel Allison committed
33

34
def create_model(data_format):
  """Model to recognize digits in the MNIST dataset.

  Network structure is equivalent to:
  https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
  and
  https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py

  But uses the tf.keras API.

  Args:
    data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
      typically faster on GPUs while 'channels_last' is typically faster on
      CPUs. See
      https://www.tensorflow.org/performance/performance_guide#data_formats

  Returns:
    A tf.keras.Model.

  Raises:
    ValueError: If data_format is neither 'channels_first' nor
      'channels_last'.
  """
  # Validate with an explicit exception rather than `assert`: assertions are
  # stripped when Python runs with -O, which would let a bad value through.
  if data_format == 'channels_first':
    input_shape = [1, 28, 28]
  elif data_format == 'channels_last':
    input_shape = [28, 28, 1]
  else:
    raise ValueError(
        "data_format must be 'channels_first' or 'channels_last', got %r" %
        (data_format,))

  l = tf.keras.layers
  # A single pooling layer instance is placed after each convolution below.
  max_pool = l.MaxPooling2D(
      (2, 2), (2, 2), padding='same', data_format=data_format)
  # The model consists of a sequential chain of layers, so tf.keras.Sequential
  # (a subclass of tf.keras.Model) makes for a compact description.
  return tf.keras.Sequential(
      [
          # Inputs arrive flattened (28 * 28 values per example) and are
          # reshaped to the image layout selected by data_format.
          l.Reshape(
              target_shape=input_shape,
              input_shape=(28 * 28,)),
          l.Conv2D(
              32,
              5,
              padding='same',
              data_format=data_format,
              activation=tf.nn.relu),
          max_pool,
          l.Conv2D(
              64,
              5,
              padding='same',
              data_format=data_format,
              activation=tf.nn.relu),
          max_pool,
          l.Flatten(),
          l.Dense(1024, activation=tf.nn.relu),
          l.Dropout(0.4),
          # Final layer emits 10 raw logits; no activation is applied here.
          l.Dense(10)
      ])
Asim Shankar's avatar
Asim Shankar committed
88
89


90
def define_mnist_flags():
  """Define the flags for the MNIST script and override their defaults.

  Calls the flags_core helpers for the base, performance and image flag
  groups, adopts flags_core's key flags into this module, and then sets the
  MNIST-specific default values.
  """
  flags_core.define_base()
  flags_core.define_performance(
      num_parallel_calls=False, inter_op=True, intra_op=True)
  flags_core.define_image()
  flags.adopt_module_key_flags(flags_core)

  # Defaults specific to this model.
  mnist_defaults = {
      'data_dir': '/tmp/mnist_data',
      'model_dir': '/tmp/mnist_model',
      'batch_size': 100,
      'train_epochs': 40,
  }
  flags_core.set_defaults(**mnist_defaults)


Asim Shankar's avatar
Asim Shankar committed
102
103
def model_fn(features, labels, mode, params):
  """Build the EstimatorSpec for the requested Estimator mode.

  Args:
    features: Either the image tensor itself, or a dict holding it under the
      'image' key.
    labels: Labels fed to the sparse softmax cross-entropy loss and to the
      accuracy metric. Not read in PREDICT mode.
    mode: A tf.estimator.ModeKeys value.
    params: Dict containing 'data_format', which is forwarded to
      create_model.

  Returns:
    A tf.estimator.EstimatorSpec for PREDICT, TRAIN or EVAL. For any other
    mode the function falls through and returns None.
  """
  mode_keys = tf.estimator.ModeKeys
  model = create_model(params['data_format'])
  # Accept both a bare tensor and a {'image': tensor} dict.
  image = features['image'] if isinstance(features, dict) else features

  if mode == mode_keys.PREDICT:
    logits = model(image, training=False)
    predicted_classes = tf.argmax(logits, axis=1)
    class_probabilities = tf.nn.softmax(logits)
    predictions = {
        'classes': predicted_classes,
        'probabilities': class_probabilities,
    }
    export_outputs = {
        'classify': tf.estimator.export.PredictOutput(predictions),
    }
    return tf.estimator.EstimatorSpec(
        mode=mode_keys.PREDICT,
        predictions=predictions,
        export_outputs=export_outputs)

  if mode == mode_keys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

    logits = model(image, training=True)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    predicted_classes = tf.argmax(logits, axis=1)
    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=predicted_classes)
    # Second element of the tuple returned by tf.metrics.accuracy.
    accuracy_update_op = accuracy[1]

    # Give fixed names to the tensors that LoggingTensorHook looks up.
    tf.identity(LEARNING_RATE, name='learning_rate')
    tf.identity(loss, name='cross_entropy')
    tf.identity(accuracy_update_op, name='train_accuracy')

    # Record the accuracy scalar as a Tensorboard summary.
    tf.summary.scalar('train_accuracy', accuracy_update_op)

    global_step = tf.train.get_or_create_global_step()
    train_op = optimizer.minimize(loss, global_step)
    return tf.estimator.EstimatorSpec(
        mode=mode_keys.TRAIN, loss=loss, train_op=train_op)

  if mode == mode_keys.EVAL:
    logits = model(image, training=False)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    eval_accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(logits, axis=1))
    return tf.estimator.EstimatorSpec(
        mode=mode_keys.EVAL,
        loss=loss,
        eval_metric_ops={'accuracy': eval_accuracy})
152
153


154
155
156
157
158
159
def run_mnist(flags_obj):
  """Train and evaluate the MNIST Estimator, then optionally export it.

  Training and evaluation alternate for
  `train_epochs // epochs_between_evals` cycles, stopping early once
  model_helpers.past_stop_threshold reports that the evaluation accuracy has
  passed `stop_threshold`. If `export_dir` is set, a SavedModel is written
  afterwards.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  model_helpers.apply_clean(flags_obj)

  session_config = tf.ConfigProto(
      allow_soft_placement=True,
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_core.get_num_gpus(flags_obj),
      all_reduce_alg=flags_obj.all_reduce_alg)

  run_config = tf.estimator.RunConfig(
      train_distribute=strategy, session_config=session_config)

  # When no data format is given, choose one based on whether TensorFlow was
  # built with CUDA.
  data_format = flags_obj.data_format
  if data_format is None:
    if tf.test.is_built_with_cuda():
      data_format = 'channels_first'
    else:
      data_format = 'channels_last'

  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=flags_obj.model_dir,
      config=run_config,
      params={'data_format': data_format})

  def train_input_fn():
    """Return the cached, shuffled, batched and repeated training dataset."""
    examples = dataset.train(flags_obj.data_dir).cache()
    # A larger shuffle buffer gives better randomness at the cost of memory;
    # a 50000-element buffer is used here.
    examples = examples.shuffle(buffer_size=50000)
    batches = examples.batch(flags_obj.batch_size)
    # One call to train() covers `epochs_between_evals` passes over the data.
    return batches.repeat(flags_obj.epochs_between_evals)

  def eval_input_fn():
    """Return the next batch from a one-shot iterator over the test set."""
    test_batches = dataset.test(flags_obj.data_dir).batch(
        flags_obj.batch_size)
    return test_batches.make_one_shot_iterator().get_next()

  # Build the training hooks requested through flags_obj.hooks.
  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  # Alternate between training and evaluation.
  num_cycles = flags_obj.train_epochs // flags_obj.epochs_between_evals
  for _ in range(num_cycles):
    mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print('\nEvaluation results:\n\t%s\n' % eval_results)

    reached_threshold = model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy'])
    if reached_threshold:
      break

  # Nothing left to do unless an export directory was requested.
  if flags_obj.export_dir is None:
    return

  # Export the model with a raw serving input receiver keyed by 'image'.
  image = tf.placeholder(tf.float32, [None, 28, 28])
  serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
      {'image': image})
  mnist_classifier.export_savedmodel(
      flags_obj.export_dir, serving_input_fn, strip_default_attrs=True)
230
231


232
233
234
235
def main(_):
  """Entry point handed to absl_app.run; the positional argument is unused."""
  parsed_flags = flags.FLAGS
  run_mnist(parsed_flags)


236
if __name__ == '__main__':
  # Set TensorFlow logging to INFO verbosity.
  tf.logging.set_verbosity(tf.logging.INFO)
  # Flags must be defined before absl parses the command line and calls main.
  define_mnist_flags()
  absl_app.run(main)