Commit f5a953c8 authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Make the input_fn return the dataset directly.

PiperOrigin-RevId: 189060074
parent bd855ed1
...@@ -69,15 +69,7 @@ def create_input_fn(file_pattern, ...@@ -69,15 +69,7 @@ def create_input_fn(file_pattern,
repeat=repeat, repeat=repeat,
use_tpu=use_tpu) use_tpu=use_tpu)
# We must use an initializable iterator, rather than a one-shot iterator, return dataset
# because the input pipeline contains a stateful table that requires
# initialization. We add the initializer to the TABLE_INITIALIZERS
# collection to ensure it is run during initialization.
iterator = dataset.make_initializable_iterator()
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
inputs = iterator.get_next()
return inputs, inputs.pop("labels", None)
return input_fn return input_fn
...@@ -103,6 +95,14 @@ def create_model_fn(model_class, hparams, use_tpu=False): ...@@ -103,6 +95,14 @@ def create_model_fn(model_class, hparams, use_tpu=False):
if "batch_size" in params: if "batch_size" in params:
hparams.batch_size = params["batch_size"] hparams.batch_size = params["batch_size"]
# Allow labels to be passed in the features dictionary.
if "labels" in features:
if labels is not None and labels is not features["labels"]:
raise ValueError(
"Conflicting labels: features['labels'] = %s, labels = %s" %
(features["labels"], labels))
labels = features.pop("labels")
model = model_class(features, labels, hparams, mode) model = model_class(features, labels, hparams, mode)
model.build() model.build()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment