Commit e8a7f352 authored by Neal Wu's avatar Neal Wu
Browse files

Return datasets directly from input_fns for TF1.5

parent cacb4c6d
......@@ -2,11 +2,11 @@
The TensorFlow official models are a collection of example models that use TensorFlow's high-level APIs. They are intended to be well-maintained, tested, and kept up to date with the latest TensorFlow API. They should also be reasonably optimized for fast performance while still being easy to read.
The master branch of the models are **in development**, and they target the [nightly binaries](https://github.com/tensorflow/tensorflow#installation) built from the [master branch of TensorFlow](https://github.com/tensorflow/tensorflow/tree/master).
The master branch of the models are **in development**, and they target the [nightly binaries](https://github.com/tensorflow/tensorflow#installation) built from the [master branch of TensorFlow](https://github.com/tensorflow/tensorflow/tree/master). We aim to keep them backwards compatible with the latest release when possible (currently TensorFlow 1.5), but we cannot always guarantee compatibility.
**Stable versions** of the official models targeting releases of TensorFlow are available as tagged branches or [downloadable releases](https://github.com/tensorflow/models/releases). Model repository version numbers match the target TensorFlow release, such that [branch r1.4.0](https://github.com/tensorflow/models/tree/r1.4.0) and [release v1.4.0](https://github.com/tensorflow/models/releases/tag/v1.4.0) are compatible with [TensorFlow v1.4.0](https://github.com/tensorflow/tensorflow/releases/tag/v1.4.0).
If you are on a version of TensorFlow earlier than v1.4, please [update your installation](https://www.tensorflow.org/install/).
If you are on a version of TensorFlow earlier than 1.4, please [update your installation](https://www.tensorflow.org/install/).
---
......
......@@ -142,7 +142,7 @@ def validate_batch_size_for_multi_gpu(batch_size):
if not num_gpus:
raise ValueError('Multi-GPU mode was specified, but no GPUs '
'were found. To use CPU, run without --multi_gpu.')
remainder = batch_size % num_gpus
if remainder:
err = ('When running with multiple GPUs, batch size '
......@@ -184,8 +184,7 @@ def main(unused_argv):
ds = dataset.train(FLAGS.data_dir)
ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
FLAGS.train_epochs)
(images, labels) = ds.make_one_shot_iterator().get_next()
return (images, labels)
return ds
# Set up training hook that logs the training accuracy every 100 steps.
tensors_to_log = {'train_accuracy': 'train_accuracy'}
......
......@@ -192,10 +192,7 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
return dataset
def main(unused_argv):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment