mnist.py 8.59 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#  Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

20
21
from absl import app as absl_app
from absl import flags
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
22
from six.moves import range
Karmel Allison's avatar
Karmel Allison committed
23
import tensorflow as tf  # pylint: disable=g-bad-import-order
24

25
from official.mnist import dataset
26
from official.utils.flags import core as flags_core
27
from official.utils.logs import hooks_helper
28
from official.utils.misc import distribution_utils
29
from official.utils.misc import model_helpers
30

31

32
LEARNING_RATE = 1e-4
33

Karmel Allison's avatar
Karmel Allison committed
34

35
def create_model(data_format):
  """Model to recognize digits in the MNIST dataset.

  Network structure is equivalent to:
  https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
  and
  https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py

  But uses the tf.keras API.

  Args:
    data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
      typically faster on GPUs while 'channels_last' is typically faster on
      CPUs. See
      https://www.tensorflow.org/performance/performance_guide#data_formats

  Returns:
    A tf.keras.Model.

  Raises:
    ValueError: If `data_format` is neither 'channels_first' nor
      'channels_last'.
  """
  if data_format == 'channels_first':
    input_shape = [1, 28, 28]
  elif data_format == 'channels_last':
    input_shape = [28, 28, 1]
  else:
    # Raise explicitly instead of using `assert`: assertions are stripped when
    # Python runs with -O, which would let an invalid value fall through and
    # silently be treated as 'channels_last'.
    raise ValueError(
        "data_format must be 'channels_first' or 'channels_last', got %r" %
        (data_format,))

  l = tf.keras.layers
  # A single pooling layer object is created and placed after both
  # convolutions below.
  max_pool = l.MaxPooling2D(
      (2, 2), (2, 2), padding='same', data_format=data_format)
  # The model consists of a sequential chain of layers, so tf.keras.Sequential
  # (a subclass of tf.keras.Model) makes for a compact description.
  return tf.keras.Sequential(
      [
          # The declared input is a flat vector of 28 * 28 values, reshaped
          # into a single-channel image laid out according to `data_format`.
          l.Reshape(
              target_shape=input_shape,
              input_shape=(28 * 28,)),
          l.Conv2D(
              32,
              5,
              padding='same',
              data_format=data_format,
              activation=tf.nn.relu),
          max_pool,
          l.Conv2D(
              64,
              5,
              padding='same',
              data_format=data_format,
              activation=tf.nn.relu),
          max_pool,
          l.Flatten(),
          l.Dense(1024, activation=tf.nn.relu),
          l.Dropout(0.4),
          # No activation on the last layer: the model emits 10 raw logits.
          l.Dense(10)
      ])
Asim Shankar's avatar
Asim Shankar committed
89
90


91
def define_mnist_flags():
  """Registers the command-line flags used by the MNIST script.

  Defines the base, performance and image flag groups through `flags_core`,
  adopts the key flags of `flags_core` as key flags of this module, and then
  overrides the defaults for the data directory, model directory, batch size
  and number of training epochs.
  """
  # Each keyword names one flag of the group; presumably True defines the flag
  # and False leaves it out -- confirm against `flags_core`.
  flags_core.define_base(
      clean=True,
      train_epochs=True,
      epochs_between_evals=True,
      stop_threshold=True,
      num_gpu=True,
      hooks=True,
      export_dir=True,
      distribution_strategy=True)
  flags_core.define_performance(
      inter_op=True,
      intra_op=True,
      num_parallel_calls=False,
      all_reduce_alg=True)
  flags_core.define_image()
  flags.adopt_module_key_flags(flags_core)

  # MNIST-specific defaults for flags defined above.
  flags_core.set_defaults(
      data_dir='/tmp/mnist_data',
      model_dir='/tmp/mnist_model',
      batch_size=100,
      train_epochs=40)


Asim Shankar's avatar
Asim Shankar committed
108
109
def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator.

  Args:
    features: Either the input image tensor itself, or a dict that holds the
      image tensor under the 'image' key.
    labels: Class-index labels fed to the sparse softmax cross-entropy loss
      and the accuracy metric. Not read in PREDICT mode.
    mode: One of the `tf.estimator.ModeKeys` values (PREDICT, TRAIN or EVAL).
    params: Dict of parameters; only `params['data_format']` is read and it is
      forwarded to `create_model`.

  Returns:
    A `tf.estimator.EstimatorSpec` describing the graph for `mode`.

  Raises:
    ValueError: If `mode` is not one of PREDICT, TRAIN or EVAL.
  """
  model = create_model(params['data_format'])
  image = features
  if isinstance(image, dict):
    image = features['image']

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits),
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=LEARNING_RATE)

    logits = model(image, training=True)
    loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(labels=labels,
                                                            logits=logits)
    # The v1 metric returns a (value, update_op) pair; index 1 is the update
    # op, so evaluating it also advances the running accuracy.
    accuracy = tf.compat.v1.metrics.accuracy(
        labels=labels, predictions=tf.argmax(logits, axis=1))

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(LEARNING_RATE, 'learning_rate')
    tf.identity(loss, 'cross_entropy')
    tf.identity(accuracy[1], name='train_accuracy')

    # Save accuracy scalar to Tensorboard output.
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        train_op=optimizer.minimize(
            loss,
            tf.compat.v1.train.get_or_create_global_step()))

  if mode == tf.estimator.ModeKeys.EVAL:
    logits = model(image, training=False)
    # Use the same tf.compat.v1 loss and metric as the TRAIN branch; the
    # un-prefixed `tf.losses` / `tf.metrics` names are not the same API under
    # TF 2.x.
    loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(labels=labels,
                                                            logits=logits)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metric_ops={
            'accuracy':
                tf.compat.v1.metrics.accuracy(
                    labels=labels, predictions=tf.argmax(logits, axis=1)),
        })

  # Fail loudly instead of implicitly returning None for an unknown mode.
  raise ValueError('Unsupported Estimator mode: %r' % (mode,))
161
162


163
164
165
166
167
168
def run_mnist(flags_obj):
  """Run MNIST training and eval loop.

  Builds an Estimator around `model_fn`, alternates training and evaluation
  until the configured number of epochs has run or the stop threshold is
  reached, and optionally exports the trained model.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  model_helpers.apply_clean(flags_obj)
  model_function = model_fn

  session_config = tf.compat.v1.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_core.get_num_gpus(flags_obj),
      all_reduce_alg=flags_obj.all_reduce_alg)

  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy, session_config=session_config)

  # When no data format is given, choose it from whether this TensorFlow build
  # has CUDA support.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      config=run_config,
      params={
          'data_format': data_format,
      })

  # Set up training and evaluation input functions.
  def train_input_fn():
    """Prepare data for training."""

    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
    ds = dataset.train(flags_obj.data_dir)
    ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)

    # Iterate through the dataset a set number (`epochs_between_evals`) of times
    # during each training session.
    ds = ds.repeat(flags_obj.epochs_between_evals)
    return ds

  def eval_input_fn():
    """Prepare data for evaluation."""
    ds = dataset.test(flags_obj.data_dir).batch(flags_obj.batch_size)
    # Build the iterator with the tf.compat.v1 helper, matching the rest of
    # this function; the `Dataset.make_one_shot_iterator()` method is not
    # available on TF 2.x dataset objects.
    return tf.compat.v1.data.make_one_shot_iterator(ds).get_next()

  # Set up hook that outputs training logs every 100 steps.
  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks, model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  # Train and evaluate model.
  # NOTE(review): floor division means a remainder of
  # train_epochs % epochs_between_evals epochs is never run, and no training
  # happens at all when train_epochs < epochs_between_evals.
  for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
    mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print('\nEvaluation results:\n\t%s\n' % eval_results)

    if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                         eval_results['accuracy']):
      break

  # Export the model
  if flags_obj.export_dir is not None:
    image = tf.compat.v1.placeholder(tf.float32, [None, 28, 28])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    # NOTE(review): `export_savedmodel` appears to be deprecated in favour of
    # `export_saved_model` in newer Estimator releases -- confirm against the
    # targeted TensorFlow version before switching.
    mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn,
                                       strip_default_attrs=True)
239
240


241
242
243
244
def main(_):
  """Runs the MNIST train/eval loop with the globally parsed absl flags.

  The single positional argument is ignored.
  """
  run_mnist(flags.FLAGS)


245
if __name__ == '__main__':
  # Script entry point: set TensorFlow logging to INFO, register the flags,
  # then hand control to absl, which calls `main`.
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  define_mnist_flags()
  absl_app.run(main)