"research/astronet/astrowavenet/__init__.py" did not exist on "8b200ebbf248c7d0424fa29b78121eac1c9b23ab"
mnist.py 8.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#  Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

20
21
from absl import app as absl_app
from absl import flags
Karmel Allison's avatar
Karmel Allison committed
22
import tensorflow as tf  # pylint: disable=g-bad-import-order
23

24
from official.mnist import dataset
25
from official.utils.flags import core as flags_core
26
from official.utils.logs import hooks_helper
27
from official.utils.misc import distribution_utils
28
from official.utils.misc import model_helpers
29

30

31
# Learning rate handed to the Adam optimizer in model_fn's TRAIN branch; it is
# also exposed there as the 'learning_rate' tensor for logging.
LEARNING_RATE = 1e-4
32

Karmel Allison's avatar
Karmel Allison committed
33

34
def create_model(data_format):
  """Model to recognize digits in the MNIST dataset.

  Network structure is equivalent to:
  https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
  and
  https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py

  But uses the tf.keras API.

  Args:
    data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
      typically faster on GPUs while 'channels_last' is typically faster on
      CPUs. See
      https://www.tensorflow.org/performance/performance_guide#data_formats

  Returns:
    A tf.keras.Model.

  Raises:
    ValueError: If `data_format` is neither 'channels_first' nor
      'channels_last'.
  """
  # Validate explicitly instead of using `assert`: assertions are stripped
  # under `python -O`, which would let an unsupported value slip through and
  # pair a channels_last input shape with a mismatching `data_format`.
  if data_format == 'channels_first':
    input_shape = [1, 28, 28]
  elif data_format == 'channels_last':
    input_shape = [28, 28, 1]
  else:
    raise ValueError(
        "data_format must be 'channels_first' or 'channels_last', got %r" %
        (data_format,))

  l = tf.keras.layers
  # A single pooling layer instance is applied after each convolution below.
  max_pool = l.MaxPooling2D(
      (2, 2), (2, 2), padding='same', data_format=data_format)
  # The model consists of a sequential chain of layers, so tf.keras.Sequential
  # (a subclass of tf.keras.Model) makes for a compact description.
  return tf.keras.Sequential(
      [
          # Inputs arrive flattened (28 * 28 values per image) and are
          # reshaped into the layout selected by `data_format`.
          l.Reshape(
              target_shape=input_shape,
              input_shape=(28 * 28,)),
          l.Conv2D(
              32,
              5,
              padding='same',
              data_format=data_format,
              activation=tf.nn.relu),
          max_pool,
          l.Conv2D(
              64,
              5,
              padding='same',
              data_format=data_format,
              activation=tf.nn.relu),
          max_pool,
          l.Flatten(),
          l.Dense(1024, activation=tf.nn.relu),
          l.Dropout(0.4),
          # Final layer emits one unnormalized score (logit) per digit class.
          l.Dense(10)
      ])
Asim Shankar's avatar
Asim Shankar committed
88
89


90
def define_mnist_flags():
  """Registers the command-line flags for this script and sets their defaults.

  Delegates to the shared `flags_core` helpers for the base, performance and
  image flag groups, adopts them as key flags of this module, and finally
  overrides a handful of defaults with MNIST-specific values.
  """
  flags_core.define_base(
      clean=True,
      train_epochs=True,
      epochs_between_evals=True,
      stop_threshold=True,
      num_gpu=True,
      hooks=True,
      export_dir=True,
      distribution_strategy=True)

  # `num_parallel_calls` is explicitly switched off for this script.
  flags_core.define_performance(
      inter_op=True,
      intra_op=True,
      num_parallel_calls=False,
      all_reduce_alg=True)

  flags_core.define_image()
  flags.adopt_module_key_flags(flags_core)

  # MNIST-specific default values.
  flags_core.set_defaults(
      data_dir='/tmp/mnist_data',
      model_dir='/tmp/mnist_model',
      batch_size=100,
      train_epochs=40)


Asim Shankar's avatar
Asim Shankar committed
107
108
def model_fn(features, labels, mode, params):
  """Builds the EstimatorSpec for MNIST prediction, training or evaluation.

  Args:
    features: Either the image tensor itself, or a dict that holds it under
      the 'image' key.
    labels: Label tensor; only read in TRAIN and EVAL modes.
    mode: One of the tf.estimator.ModeKeys values.
    params: Dict of hyperparameters; 'data_format' is read from it and passed
      to create_model.

  Returns:
    A tf.estimator.EstimatorSpec for the PREDICT, TRAIN or EVAL mode. Any
    other mode falls through and returns None.
  """
  mode_keys = tf.estimator.ModeKeys
  network = create_model(params['data_format'])

  # Accept both a bare tensor and a {'image': tensor} feature dict.
  images = features['image'] if isinstance(features, dict) else features

  if mode == mode_keys.PREDICT:
    logits = network(images, training=False)
    outputs = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits),
    }
    serving_signatures = {
        'classify': tf.estimator.export.PredictOutput(outputs),
    }
    return tf.estimator.EstimatorSpec(
        mode=mode_keys.PREDICT,
        predictions=outputs,
        export_outputs=serving_signatures)

  if mode == mode_keys.TRAIN:
    adam = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

    logits = network(images, training=True)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    predicted_classes = tf.argmax(logits, axis=1)
    # tf.metrics.accuracy returns a (value, update_op) pair; only the second
    # element is used for logging and the summary below.
    _, accuracy_update = tf.metrics.accuracy(
        labels=labels, predictions=predicted_classes)

    # Give the tensors fixed names so a LoggingTensorHook can look them up.
    tf.identity(LEARNING_RATE, name='learning_rate')
    tf.identity(loss, name='cross_entropy')
    tf.identity(accuracy_update, name='train_accuracy')

    # Also write the training accuracy as a TensorBoard scalar.
    tf.summary.scalar('train_accuracy', accuracy_update)

    train_op = adam.minimize(loss, tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(
        mode=mode_keys.TRAIN, loss=loss, train_op=train_op)

  if mode == mode_keys.EVAL:
    logits = network(images, training=False)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    accuracy_metric = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(logits, axis=1))
    return tf.estimator.EstimatorSpec(
        mode=mode_keys.EVAL,
        loss=loss,
        eval_metric_ops={'accuracy': accuracy_metric})
157
158


159
160
161
162
163
164
def run_mnist(flags_obj):
  """Run MNIST training and eval loop.

  Training proceeds in cycles of at most `epochs_between_evals` epochs, with an
  evaluation after every cycle, until `train_epochs` epochs have been run or
  the evaluation accuracy passes `stop_threshold`. When `train_epochs` is not
  a multiple of `epochs_between_evals`, the last cycle is shortened so that
  exactly `train_epochs` epochs are trained. If `export_dir` is set, the
  trained model is exported as a SavedModel afterwards.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If `flags_obj.epochs_between_evals` is less than 1.
  """
  # Reject a non-positive cycle length before anything else is touched; it
  # would otherwise surface late as a division error or as a silent no-op.
  epochs_between_evals = flags_obj.epochs_between_evals
  if epochs_between_evals < 1:
    raise ValueError(
        'epochs_between_evals must be at least 1, got %r' %
        (epochs_between_evals,))

  model_helpers.apply_clean(flags_obj)

  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_core.get_num_gpus(flags_obj),
      all_reduce_alg=flags_obj.all_reduce_alg)

  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy, session_config=session_config)

  # Without an explicit flag, choose the layout from the TensorFlow build:
  # channels_first for CUDA builds, channels_last otherwise.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=flags_obj.model_dir,
      config=run_config,
      params={
          'data_format': data_format,
      })

  # Set up training and evaluation input functions.
  def make_train_input_fn(num_epochs):
    """Returns an input_fn that makes `num_epochs` passes over the train set."""

    def train_input_fn():
      """Prepare data for training."""

      # When choosing shuffle buffer sizes, larger sizes result in better
      # randomness, while smaller sizes use less memory. MNIST is a small
      # enough dataset that we can easily shuffle the full epoch.
      ds = dataset.train(flags_obj.data_dir)
      ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)

      # Iterate through the dataset `num_epochs` times during this training
      # session.
      return ds.repeat(num_epochs)

    return train_input_fn

  def eval_input_fn():
    """Prepare a single pass over the test data for evaluation."""
    return dataset.test(flags_obj.data_dir).batch(
        flags_obj.batch_size).make_one_shot_iterator().get_next()

  # Build the training hooks requested through the `hooks` flag.
  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks, model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  # Train and evaluate model. Counting down the remaining epochs (instead of
  # looping `train_epochs // epochs_between_evals` times) keeps the remainder
  # epochs from being dropped.
  epochs_remaining = flags_obj.train_epochs
  while epochs_remaining > 0:
    epochs_this_cycle = min(epochs_between_evals, epochs_remaining)
    mnist_classifier.train(
        input_fn=make_train_input_fn(epochs_this_cycle), hooks=train_hooks)
    epochs_remaining -= epochs_this_cycle

    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print('\nEvaluation results:\n\t%s\n' % eval_results)

    if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                         eval_results['accuracy']):
      break

  # Export the model
  if flags_obj.export_dir is not None:
    # The serving signature takes a batch of 28x28 float images under the
    # 'image' key, which model_fn unwraps from the feature dict.
    image = tf.placeholder(tf.float32, [None, 28, 28])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn,
                                       strip_default_attrs=True)
235
236


237
238
239
240
def main(_):
  """Runs the MNIST loop with the globally parsed absl flags.

  The single positional argument is accepted but ignored.
  """
  parsed_flags = flags.FLAGS
  run_mnist(parsed_flags)


241
if __name__ == '__main__':
  # Script entry point: turn on INFO-level TensorFlow logging, register the
  # MNIST flags, then let absl parse the command line and invoke main().
  tf.logging.set_verbosity(tf.logging.INFO)
  define_mnist_flags()
  absl_app.run(main)