"vscode:/vscode.git/clone" did not exist on "0fdbf1bd285cf55d2d24f7659d6734aa3bd6ef70"
Commit aead3912 authored by Derek Murray's avatar Derek Murray Committed by Taylor Robie
Browse files

Update mnist_tpu.py with recommended tf.data APIs (#5853)

1. `tf.contrib.data.batch_and_drop_remainder()` has been deprecated for a while now.
2. `Dataset.make_one_shot_iterator()` is no longer necessary in input functions, and avoiding it lead to better performance on TPUs. (It will also be removed in TF 2.0.)
parent cc0ad1cb
...@@ -127,19 +127,15 @@ def train_input_fn(params): ...@@ -127,19 +127,15 @@ def train_input_fn(params):
# computed according to the input pipeline deployment. See # computed according to the input pipeline deployment. See
# `tf.contrib.tpu.RunConfig` for details. # `tf.contrib.tpu.RunConfig` for details.
ds = dataset.train(data_dir).cache().repeat().shuffle( ds = dataset.train(data_dir).cache().repeat().shuffle(
buffer_size=50000).apply( buffer_size=50000).batch(batch_size, drop_remainder=True)
tf.contrib.data.batch_and_drop_remainder(batch_size)) return ds
images, labels = ds.make_one_shot_iterator().get_next()
return images, labels
def eval_input_fn(params): def eval_input_fn(params):
batch_size = params["batch_size"] batch_size = params["batch_size"]
data_dir = params["data_dir"] data_dir = params["data_dir"]
ds = dataset.test(data_dir).apply( ds = dataset.test(data_dir).batch(batch_size, drop_remainder=True)
tf.contrib.data.batch_and_drop_remainder(batch_size)) return ds
images, labels = ds.make_one_shot_iterator().get_next()
return images, labels
def predict_input_fn(params): def predict_input_fn(params):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment