Unverified Commit 223e0f3a authored by Asim Shankar's avatar Asim Shankar Committed by GitHub
Browse files

Merge pull request #5379 from aman2930/master

Enabling prediction in mnist_tpu.
parents 1b6ca655 ea24314d
...@@ -65,6 +65,7 @@ tf.flags.DEFINE_integer("eval_steps", 0, ...@@ -65,6 +65,7 @@ tf.flags.DEFINE_integer("eval_steps", 0,
tf.flags.DEFINE_float("learning_rate", 0.05, "Learning rate.") tf.flags.DEFINE_float("learning_rate", 0.05, "Learning rate.")
tf.flags.DEFINE_bool("use_tpu", True, "Use TPUs rather than plain CPUs") tf.flags.DEFINE_bool("use_tpu", True, "Use TPUs rather than plain CPUs")
tf.flags.DEFINE_bool("enable_predict", True, "Do some predictions at the end")
tf.flags.DEFINE_integer("iterations", 50, tf.flags.DEFINE_integer("iterations", 50,
"Number of iterations per TPU training loop.") "Number of iterations per TPU training loop.")
tf.flags.DEFINE_integer("num_shards", 8, "Number of shards (TPU chips).") tf.flags.DEFINE_integer("num_shards", 8, "Number of shards (TPU chips).")
...@@ -82,13 +83,20 @@ def model_fn(features, labels, mode, params): ...@@ -82,13 +83,20 @@ def model_fn(features, labels, mode, params):
"""model_fn constructs the ML model used to predict handwritten digits.""" """model_fn constructs the ML model used to predict handwritten digits."""
del params del params
if mode == tf.estimator.ModeKeys.PREDICT:
raise RuntimeError("mode {} is not supported yet".format(mode))
image = features image = features
if isinstance(image, dict): if isinstance(image, dict):
image = features["image"] image = features["image"]
model = mnist.create_model("channels_last") model = mnist.create_model("channels_last")
if mode == tf.estimator.ModeKeys.PREDICT:
logits = model(image, training=False)
predictions = {
'class_ids': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits),
}
return tf.contrib.tpu.TPUEstimatorSpec(mode, predictions=predictions)
logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN)) logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
...@@ -134,6 +142,14 @@ def eval_input_fn(params): ...@@ -134,6 +142,14 @@ def eval_input_fn(params):
return images, labels return images, labels
def predict_input_fn(params):
batch_size = params["batch_size"]
data_dir = params["data_dir"]
# Take out top 10 samples from test data to make the predictions.
ds = dataset.test(data_dir).take(10).batch(batch_size)
return ds
def main(argv): def main(argv):
del argv # Unused. del argv # Unused.
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
...@@ -157,6 +173,7 @@ def main(argv): ...@@ -157,6 +173,7 @@ def main(argv):
use_tpu=FLAGS.use_tpu, use_tpu=FLAGS.use_tpu,
train_batch_size=FLAGS.batch_size, train_batch_size=FLAGS.batch_size,
eval_batch_size=FLAGS.batch_size, eval_batch_size=FLAGS.batch_size,
predict_batch_size=FLAGS.batch_size,
params={"data_dir": FLAGS.data_dir}, params={"data_dir": FLAGS.data_dir},
config=run_config) config=run_config)
# TPUEstimator.train *requires* a max_steps argument. # TPUEstimator.train *requires* a max_steps argument.
...@@ -168,6 +185,18 @@ def main(argv): ...@@ -168,6 +185,18 @@ def main(argv):
if FLAGS.eval_steps: if FLAGS.eval_steps:
estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.eval_steps) estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.eval_steps)
# Run prediction on top few samples of test data.
if FLAGS.enable_predict:
predictions = estimator.predict(input_fn=predict_input_fn)
for pred_dict in predictions:
template = ('Prediction is "{}" ({:.1f}%).')
class_id = pred_dict['class_ids']
probability = pred_dict['probabilities'][class_id]
print(template.format(class_id, 100 * probability))
if __name__ == "__main__": if __name__ == "__main__":
tf.app.run() tf.app.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment