# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MNIST model training using TPUs.

This program demonstrates training of the convolutional neural network model
defined in mnist.py on Google Cloud TPUs (https://cloud.google.com/tpu/).

If you are not interested in TPUs, you should ignore this file.
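
Example invocation (flag values are illustrative):

  python mnist_tpu.py --tpu=$TPU_NAME --data_dir=$DATA_DIR \
      --model_dir=$MODEL_DIR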
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

import tensorflow as tf  # pylint: disable=g-bad-import-order

# In the open-source environment, add the grandparent directory for imports.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(sys.path[0]))))

from official.mnist import dataset  # pylint: disable=wrong-import-position
from official.mnist import mnist  # pylint: disable=wrong-import-position

# Cloud TPU Cluster Resolver flags
tf.flags.DEFINE_string(
    "tpu", default=None,
    help="The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")
tf.flags.DEFINE_string(
    "tpu_zone", default=None,
    help="[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")
tf.flags.DEFINE_string(
    "gcp_project", default=None,
    help="[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")
# Model-specific parameters
tf.flags.DEFINE_string("data_dir", "",
                       "Path to directory containing the MNIST dataset")
tf.flags.DEFINE_string("model_dir", None, "Estimator model_dir")
tf.flags.DEFINE_integer("batch_size", 1024,
                        "Mini-batch size for the training. Note that this "
                        "is the global batch size and not the per-shard batch.")
tf.flags.DEFINE_integer("train_steps", 1000, "Total number of training steps.")
tf.flags.DEFINE_integer("eval_steps", 0,
                        "Total number of evaluation steps. If `0`, evaluation "
                        "after training is skipped.")
tf.flags.DEFINE_float("learning_rate", 0.05, "Learning rate.")

tf.flags.DEFINE_bool("use_tpu", True, "Use TPUs rather than plain CPUs")
tf.flags.DEFINE_bool("enable_predict", True, "Do some predictions at the end")
tf.flags.DEFINE_integer("iterations", 50,
                        "Number of iterations per TPU training loop.")
tf.flags.DEFINE_integer("num_shards", 8, "Number of shards (TPU chips).")

FLAGS = tf.flags.FLAGS


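# Note: with TPUEstimator, this metric function runs on the host CPU; the
# `eval_metrics` tuple in TPUEstimatorSpec ships `labels` and `logits` back
# from the TPU for metric computation.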
def metric_fn(labels, logits):
  accuracy = tf.metrics.accuracy(
78
      labels=labels, predictions=tf.argmax(logits, axis=1))
79
80
81
82
  return {"accuracy": accuracy}


def model_fn(features, labels, mode, params):
83
84
  """model_fn constructs the ML model used to predict handwritten digits."""

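  # `params` carries TPUEstimator-provided values such as the per-shard
  # batch size; this model_fn reads its hyperparameters from FLAGS instead.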
  del params
  image = features
  if isinstance(image, dict):
    image = features["image"]

  model = mnist.create_model("channels_last")

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
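    # `labels` is passed through so the expected digit can be printed
    # alongside each prediction.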
    predictions = {
        "class_ids": tf.argmax(logits, axis=1),
        "probabilities": tf.nn.softmax(logits),
        "label": labels,
    }
    return tf.contrib.tpu.TPUEstimatorSpec(mode, predictions=predictions)

  logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    learning_rate = tf.train.exponential_decay(
        FLAGS.learning_rate,
        tf.train.get_global_step(),
        decay_steps=100000,
        decay_rate=0.96)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
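    # On a TPU, wrap the optimizer so that gradients are aggregated across
    # all shards (an all-reduce) before being applied.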
    if FLAGS.use_tpu:
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_global_step()))

  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))


def train_input_fn(params):
  """train_input_fn defines the input pipeline used for training."""
  # Retrieve the batch size for the current shard. The number of shards is
  # computed according to the input pipeline deployment. See
  # `tf.contrib.tpu.RunConfig` for details.
  batch_size = params["batch_size"]
  data_dir = params["data_dir"]
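  # TPUs require statically shaped batches, so drop any remainder rather
  # than emit a smaller final batch.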
  ds = dataset.train(data_dir).cache().repeat().shuffle(
      buffer_size=50000).apply(
          tf.contrib.data.batch_and_drop_remainder(batch_size))
  images, labels = ds.make_one_shot_iterator().get_next()
  return images, labels


def eval_input_fn(params):
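  """eval_input_fn defines the input pipeline used for evaluation."""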
  batch_size = params["batch_size"]
  data_dir = params["data_dir"]
  ds = dataset.test(data_dir).apply(
      tf.contrib.data.batch_and_drop_remainder(batch_size))
  images, labels = ds.make_one_shot_iterator().get_next()
  return images, labels


def predict_input_fn(params):
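  """predict_input_fn defines the input pipeline used for prediction."""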
  batch_size = params["batch_size"]
  data_dir = params["data_dir"]
  # Take the first 10 samples of the test data for prediction.
  ds = dataset.test(data_dir).take(10).batch(batch_size)
  return ds


def main(argv):
  del argv  # Unused.
  tf.logging.set_verbosity(tf.logging.INFO)

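  # Resolve the TPU's gRPC endpoint from its name via the Cloud TPU APIs
  # (a grpc://... URL passed via --tpu is used directly).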
  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu,
      zone=FLAGS.tpu_zone,
      project=FLAGS.gcp_project
  )

  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=True),
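      # --iterations sets iterations_per_loop (the number of training steps
      # run on the TPU per host call); --num_shards is the number of cores.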
      tpu_config=tf.contrib.tpu.TPUConfig(FLAGS.iterations, FLAGS.num_shards),
  )

  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size,
      predict_batch_size=FLAGS.batch_size,
      params={"data_dir": FLAGS.data_dir},
      config=run_config)
  # TPUEstimator.train *requires* a max_steps argument.
  estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)
  # TPUEstimator.evaluate *requires* a steps argument.
  # Note that the number of examples used during evaluation is
  # --eval_steps * --batch_size.
  # So if you change --batch_size, change --eval_steps accordingly.
  if FLAGS.eval_steps:
    estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.eval_steps)

  # Run prediction on the test data.
  if FLAGS.enable_predict:
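    # `predict` returns a generator that yields one dictionary per example.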
    predictions = estimator.predict(input_fn=predict_input_fn)

    for pred_dict in predictions:
      template = 'Prediction is "{}" ({:.1f}%), expected "{}"'

      class_id = pred_dict["class_ids"]
      probability = pred_dict["probabilities"][class_id]
      expected_label = pred_dict["label"]

      print(template.format(class_id, 100 * probability, expected_label))


if __name__ == "__main__":
  tf.app.run()