Unverified Commit 95aec51f authored by Neal Wu's avatar Neal Wu Committed by GitHub
Browse files

Merge pull request #3334 from tensorflow/input_fn_dataset

Return datasets directly from input_fns for TF1.5
parents b7edd6e3 2ed6b9d9
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
The TensorFlow official models are a collection of example models that use TensorFlow's high-level APIs. They are intended to be well-maintained, tested, and kept up to date with the latest TensorFlow API. They should also be reasonably optimized for fast performance while still being easy to read. The TensorFlow official models are a collection of example models that use TensorFlow's high-level APIs. They are intended to be well-maintained, tested, and kept up to date with the latest TensorFlow API. They should also be reasonably optimized for fast performance while still being easy to read.
The master branch of the models are **in development**, and they target the [nightly binaries](https://github.com/tensorflow/tensorflow#installation) built from the [master branch of TensorFlow](https://github.com/tensorflow/tensorflow/tree/master). The master branch of the models are **in development**, and they target the [nightly binaries](https://github.com/tensorflow/tensorflow#installation) built from the [master branch of TensorFlow](https://github.com/tensorflow/tensorflow/tree/master). We aim to keep them backwards compatible with the latest release when possible (currently TensorFlow 1.5), but we cannot always guarantee compatibility.
**Stable versions** of the official models targeting releases of TensorFlow are available as tagged branches or [downloadable releases](https://github.com/tensorflow/models/releases). Model repository version numbers match the target TensorFlow release, such that [branch r1.4.0](https://github.com/tensorflow/models/tree/r1.4.0) and [release v1.4.0](https://github.com/tensorflow/models/releases/tag/v1.4.0) are compatible with [TensorFlow v1.4.0](https://github.com/tensorflow/tensorflow/releases/tag/v1.4.0). **Stable versions** of the official models targeting releases of TensorFlow are available as tagged branches or [downloadable releases](https://github.com/tensorflow/models/releases). Model repository version numbers match the target TensorFlow release, such that [branch r1.4.0](https://github.com/tensorflow/models/tree/r1.4.0) and [release v1.4.0](https://github.com/tensorflow/models/releases/tag/v1.4.0) are compatible with [TensorFlow v1.4.0](https://github.com/tensorflow/tensorflow/releases/tag/v1.4.0).
If you are on a version of TensorFlow earlier than v1.4, please [update your installation](https://www.tensorflow.org/install/). If you are on a version of TensorFlow earlier than 1.4, please [update your installation](https://www.tensorflow.org/install/).
--- ---
......
...@@ -142,7 +142,7 @@ def validate_batch_size_for_multi_gpu(batch_size): ...@@ -142,7 +142,7 @@ def validate_batch_size_for_multi_gpu(batch_size):
if not num_gpus: if not num_gpus:
raise ValueError('Multi-GPU mode was specified, but no GPUs ' raise ValueError('Multi-GPU mode was specified, but no GPUs '
'were found. To use CPU, run without --multi_gpu.') 'were found. To use CPU, run without --multi_gpu.')
remainder = batch_size % num_gpus remainder = batch_size % num_gpus
if remainder: if remainder:
err = ('When running with multiple GPUs, batch size ' err = ('When running with multiple GPUs, batch size '
...@@ -184,8 +184,7 @@ def main(unused_argv): ...@@ -184,8 +184,7 @@ def main(unused_argv):
ds = dataset.train(FLAGS.data_dir) ds = dataset.train(FLAGS.data_dir)
ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat( ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
FLAGS.train_epochs) FLAGS.train_epochs)
(images, labels) = ds.make_one_shot_iterator().get_next() return ds
return (images, labels)
# Set up training hook that logs the training accuracy every 100 steps. # Set up training hook that logs the training accuracy every 100 steps.
tensors_to_log = {'train_accuracy': 'train_accuracy'} tensors_to_log = {'train_accuracy': 'train_accuracy'}
......
...@@ -192,10 +192,7 @@ def input_fn(data_file, num_epochs, shuffle, batch_size): ...@@ -192,10 +192,7 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
# epochs from blending together. # epochs from blending together.
dataset = dataset.repeat(num_epochs) dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size) dataset = dataset.batch(batch_size)
return dataset
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
def main(unused_argv): def main(unused_argv):
......
...@@ -54,7 +54,9 @@ class BaseTest(tf.test.TestCase): ...@@ -54,7 +54,9 @@ class BaseTest(tf.test.TestCase):
temp_csv.write(TEST_INPUT) temp_csv.write(TEST_INPUT)
def test_input_fn(self): def test_input_fn(self):
features, labels = wide_deep.input_fn(self.input_csv, 1, False, 1) dataset = wide_deep.input_fn(self.input_csv, 1, False, 1)
features, labels = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess: with tf.Session() as sess:
features, labels = sess.run((features, labels)) features, labels = sess.run((features, labels))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment