Commit aead3912 authored by Derek Murray's avatar Derek Murray Committed by Taylor Robie
Browse files

Update mnist_tpu.py with recommended tf.data APIs (#5853)

1. `tf.contrib.data.batch_and_drop_remainder()` has been deprecated for a while now.
2. `Dataset.make_one_shot_iterator()` is no longer necessary in input functions, and avoiding it lead to better performance on TPUs. (It will also be removed in TF 2.0.)
parent cc0ad1cb
......@@ -127,19 +127,15 @@ def train_input_fn(params):
# computed according to the input pipeline deployment. See
# `tf.contrib.tpu.RunConfig` for details.
ds = dataset.train(data_dir).cache().repeat().shuffle(
buffer_size=50000).apply(
tf.contrib.data.batch_and_drop_remainder(batch_size))
images, labels = ds.make_one_shot_iterator().get_next()
return images, labels
buffer_size=50000).batch(batch_size, drop_remainder=True)
return ds
def eval_input_fn(params):
batch_size = params["batch_size"]
data_dir = params["data_dir"]
ds = dataset.test(data_dir).apply(
tf.contrib.data.batch_and_drop_remainder(batch_size))
images, labels = ds.make_one_shot_iterator().get_next()
return images, labels
ds = dataset.test(data_dir).batch(batch_size, drop_remainder=True)
return ds
def predict_input_fn(params):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment