"scripts/release_notes/classify_prs.py" did not exist on "e17f5ea2d322f5eb3cb7cb14aab9849fba013c7c"
Commit dfe2a43f authored by Aman Gupta's avatar Aman Gupta
Browse files

Enabling Prediction in mnist_tpu.

Right now we don't have input data for prediction. So using top 10
entries of test data as input.
parent 27b4acd4
...@@ -65,6 +65,7 @@ tf.flags.DEFINE_integer("eval_steps", 0, ...@@ -65,6 +65,7 @@ tf.flags.DEFINE_integer("eval_steps", 0,
tf.flags.DEFINE_float("learning_rate", 0.05, "Learning rate.") tf.flags.DEFINE_float("learning_rate", 0.05, "Learning rate.")
tf.flags.DEFINE_bool("use_tpu", True, "Use TPUs rather than plain CPUs") tf.flags.DEFINE_bool("use_tpu", True, "Use TPUs rather than plain CPUs")
tf.flags.DEFINE_bool("enable_predict", True, "Do some predictions at the end")
tf.flags.DEFINE_integer("iterations", 50, tf.flags.DEFINE_integer("iterations", 50,
"Number of iterations per TPU training loop.") "Number of iterations per TPU training loop.")
tf.flags.DEFINE_integer("num_shards", 8, "Number of shards (TPU chips).") tf.flags.DEFINE_integer("num_shards", 8, "Number of shards (TPU chips).")
...@@ -82,13 +83,24 @@ def model_fn(features, labels, mode, params): ...@@ -82,13 +83,24 @@ def model_fn(features, labels, mode, params):
"""model_fn constructs the ML model used to predict handwritten digits.""" """model_fn constructs the ML model used to predict handwritten digits."""
del params del params
if mode == tf.estimator.ModeKeys.PREDICT:
raise RuntimeError("mode {} is not supported yet".format(mode))
image = features image = features
if isinstance(image, dict): if isinstance(image, dict):
image = features["image"] image = features["image"]
model = mnist.create_model("channels_last") model = mnist.create_model("channels_last")
if mode == tf.estimator.ModeKeys.PREDICT:
logits = model(image, training=False)
predictions = {
'class_ids': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits),
'label': labels,
}
return tf.contrib.tpu.TPUEstimatorSpec(mode, predictions=predictions,
export_outputs={
'classify': tf.estimator.export.PredictOutput(predictions)
})
logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN)) logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
...@@ -134,6 +146,14 @@ def eval_input_fn(params): ...@@ -134,6 +146,14 @@ def eval_input_fn(params):
return images, labels return images, labels
def predict_input_fn(params):
batch_size = params["batch_size"]
data_dir = params["data_dir"]
# Take out top 10 samples from test data to make the predictions.
ds = dataset.test(data_dir).take(10).batch(batch_size)
return ds
def main(argv): def main(argv):
del argv # Unused. del argv # Unused.
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
...@@ -157,6 +177,7 @@ def main(argv): ...@@ -157,6 +177,7 @@ def main(argv):
use_tpu=FLAGS.use_tpu, use_tpu=FLAGS.use_tpu,
train_batch_size=FLAGS.batch_size, train_batch_size=FLAGS.batch_size,
eval_batch_size=FLAGS.batch_size, eval_batch_size=FLAGS.batch_size,
predict_batch_size=FLAGS.batch_size,
params={"data_dir": FLAGS.data_dir}, params={"data_dir": FLAGS.data_dir},
config=run_config) config=run_config)
# TPUEstimator.train *requires* a max_steps argument. # TPUEstimator.train *requires* a max_steps argument.
...@@ -168,6 +189,19 @@ def main(argv): ...@@ -168,6 +189,19 @@ def main(argv):
if FLAGS.eval_steps: if FLAGS.eval_steps:
estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.eval_steps) estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.eval_steps)
# Run prediction on the test data.
if FLAGS.enable_predict:
predictions = estimator.predict(input_fn=predict_input_fn)
for pred_dict in predictions:
template = ('Prediction is "{}" ({:.1f}%), expected "{}"')
class_id = pred_dict['class_ids']
probability = pred_dict['probabilities'][class_id]
expected_label = pred_dict['label']
print(template.format(class_id, 100 * probability, expected_label))
if __name__ == "__main__": if __name__ == "__main__":
tf.app.run() tf.app.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment