mnist_tpu.py 7.26 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MNIST model training using TPUs.

This program demonstrates training of the convolutional neural network model
defined in mnist.py on Google Cloud TPUs (https://cloud.google.com/tpu/).

If you are not interested in TPUs, you should ignore this file.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

26
27
28
import os
import sys

29
30
31
32
# pylint: disable=g-bad-import-order
from absl import app as absl_app
import tensorflow as tf
# pylint: enable=g-bad-import-order
Karmel Allison's avatar
Karmel Allison committed
33

34
35
36
37
38
# For open source environment, add grandparent directory for import
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(sys.path[0]))))

from official.mnist import dataset  # pylint: disable=wrong-import-position
from official.mnist import mnist  # pylint: disable=wrong-import-position
39

40
# Cloud TPU Cluster Resolver flags
# These three flags are forwarded to TPUClusterResolver in main().
tf.flags.DEFINE_string(
    "tpu", default=None,
    help="The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")
tf.flags.DEFINE_string(
    "tpu_zone", default=None,
    help="[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE zone from "
    "metadata.")
tf.flags.DEFINE_string(
    "gcp_project", default=None,
    help="[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")
56

Neal Wu's avatar
Neal Wu committed
57
# Model specific parameters
tf.flags.DEFINE_string(
    "data_dir", default="",
    help="Path to directory containing the MNIST dataset")
tf.flags.DEFINE_string("model_dir", default=None, help="Estimator model_dir")
tf.flags.DEFINE_integer(
    "batch_size", default=1024,
    help="Mini-batch size for the training. Note that this "
    "is the global batch size and not the per-shard batch.")
tf.flags.DEFINE_integer(
    "train_steps", default=1000, help="Total number of training steps.")
tf.flags.DEFINE_integer(
    "eval_steps", default=0,
    help="Total number of evaluation steps. If `0`, evaluation "
    "after training is skipped.")
tf.flags.DEFINE_float("learning_rate", default=0.05, help="Learning rate.")

# Execution-environment switches.
tf.flags.DEFINE_bool(
    "use_tpu", default=True, help="Use TPUs rather than plain CPUs")
tf.flags.DEFINE_bool(
    "enable_predict", default=True, help="Do some predictions at the end")
tf.flags.DEFINE_integer(
    "iterations", default=50,
    help="Number of iterations per TPU training loop.")
tf.flags.DEFINE_integer(
    "num_shards", default=8, help="Number of shards (TPU chips).")

# Module-wide handle to the parsed flag values.
FLAGS = tf.flags.FLAGS


def metric_fn(labels, logits):
  """Builds the evaluation metrics from labels and raw model outputs.

  Args:
    labels: Ground-truth labels, forwarded to `tf.metrics.accuracy`.
    logits: Model outputs; the predicted class is the argmax over axis 1.

  Returns:
    A dict with the single key "accuracy" holding the value returned by
    `tf.metrics.accuracy`.
  """
  predicted_classes = tf.argmax(logits, axis=1)
  return {
      "accuracy": tf.metrics.accuracy(
          labels=labels, predictions=predicted_classes),
  }


def model_fn(features, labels, mode, params):
  """Builds the `TPUEstimatorSpec` for the MNIST model in the given mode.

  Args:
    features: Either the image input itself, or a dict carrying it under the
      "image" key.
    labels: Labels fed to the sparse softmax cross-entropy loss. Not read in
      PREDICT mode.
    mode: A `tf.estimator.ModeKeys` value.
    params: Ignored; it is deleted immediately.

  Returns:
    A `tf.contrib.tpu.TPUEstimatorSpec` for PREDICT, TRAIN or EVAL, and
    `None` for any other mode.
  """
  del params  # Not used; the learning rate is read from FLAGS instead.
  image = features["image"] if isinstance(features, dict) else features
  model = mnist.create_model("channels_last")

  mode_keys = tf.estimator.ModeKeys

  # Prediction needs neither labels nor a loss, so it returns early.
  if mode == mode_keys.PREDICT:
    logits = model(image, training=False)
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode,
        predictions={
            "class_ids": tf.argmax(logits, axis=1),
            "probabilities": tf.nn.softmax(logits),
        })

  is_training = mode == mode_keys.TRAIN
  logits = model(image, training=is_training)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  if mode == mode_keys.EVAL:
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))

  if is_training:
    global_step = tf.train.get_global_step()
    decayed_rate = tf.train.exponential_decay(
        FLAGS.learning_rate,
        global_step,
        decay_steps=100000,
        decay_rate=0.96)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=decayed_rate)
    if FLAGS.use_tpu:
      # Wrap the optimizer for cross-shard use when running on TPU.
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(loss, global_step)
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, train_op=train_op)

  return None


def train_input_fn(params):
  """train_input_fn defines the input pipeline used for training.

  Args:
    params: Dict providing "batch_size" and "data_dir".

  Returns:
    The dataset from `dataset.train(data_dir)`, cached, repeated, shuffled
    with a 50000-element buffer and batched with `drop_remainder=True`.
  """
  # Retrieves the batch size for the current shard. The # of shards is
  # computed according to the input pipeline deployment. See
  # `tf.contrib.tpu.RunConfig` for details.
  shard_batch_size = params["batch_size"]
  source = dataset.train(params["data_dir"])
  shuffled = source.cache().repeat().shuffle(buffer_size=50000)
  return shuffled.batch(shard_batch_size, drop_remainder=True)
135
136
137
138
139


def eval_input_fn(params):
  """eval_input_fn defines the input pipeline used for evaluation.

  Args:
    params: Dict providing "batch_size" and "data_dir".

  Returns:
    The dataset from `dataset.test(data_dir)`, batched with
    `drop_remainder=True`.
  """
  shard_batch_size = params["batch_size"]
  test_data = dataset.test(params["data_dir"])
  return test_data.batch(shard_batch_size, drop_remainder=True)
142
143


Aman Gupta's avatar
Aman Gupta committed
144
145
146
147
148
149
150
151
def predict_input_fn(params):
  """predict_input_fn defines the input pipeline used for prediction.

  Args:
    params: Dict providing "batch_size" and "data_dir".

  Returns:
    The first 10 elements of `dataset.test(data_dir)`, batched by the given
    batch size.
  """
  shard_batch_size = params["batch_size"]
  test_data = dataset.test(params["data_dir"])
  # Take out top 10 samples from test data to make the predictions.
  first_samples = test_data.take(10)
  return first_samples.batch(shard_batch_size)


152
153
154
155
def main(argv):
  """Trains the MNIST model, then optionally evaluates and predicts.

  All configuration is read from FLAGS. Evaluation runs only when
  --eval_steps is non-zero, and prediction only when --enable_predict is set.

  Args:
    argv: Unused.
  """
  del argv  # Unused.
  tf.logging.set_verbosity(tf.logging.INFO)

  resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  session_config = tf.ConfigProto(
      allow_soft_placement=True, log_device_placement=True)
  tpu_config = tf.contrib.tpu.TPUConfig(FLAGS.iterations, FLAGS.num_shards)
  run_config = tf.contrib.tpu.RunConfig(
      cluster=resolver,
      model_dir=FLAGS.model_dir,
      session_config=session_config,
      tpu_config=tpu_config)

  # The same global batch size is used for training, evaluation and
  # prediction.
  global_batch_size = FLAGS.batch_size
  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=global_batch_size,
      eval_batch_size=global_batch_size,
      predict_batch_size=global_batch_size,
      params={"data_dir": FLAGS.data_dir},
      config=run_config)

  # TPUEstimator.train *requires* a max_steps argument.
  estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)

  # TPUEstimator.evaluate *requires* a steps argument.
  # Note that the number of examples used during evaluation is
  # --eval_steps * --batch_size.
  # So if you change --batch_size then change --eval_steps too.
  if FLAGS.eval_steps:
    estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.eval_steps)

  if not FLAGS.enable_predict:
    return

  # Run prediction on top few samples of test data.
  template = 'Prediction is "{}" ({:.1f}%).'
  for pred_dict in estimator.predict(input_fn=predict_input_fn):
    class_id = pred_dict["class_ids"]
    # Probability the model assigned to the class it picked.
    probability = pred_dict["probabilities"][class_id]
    print(template.format(class_id, 100 * probability))
Aman Gupta's avatar
Aman Gupta committed
198

199
200

if __name__ == "__main__":
  # absl's app.run() takes the entry point as a required argument; it parses
  # the command-line flags and then calls main(argv).
  absl_app.run(main)