Commit fe3746e6 authored by Toby Boyd

Use AUTOTUNE, remove noop take, and comment fixes

parent eb370577
@@ -68,36 +68,26 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
     Dataset of (image, label) pairs ready for iteration.
   """
 
-  # We prefetch a batch at a time, This can help smooth out the time taken to
-  # load input files as we go through shuffling and processing.
+  # Sets tf.data to AUTOTUNE, e.g. num_parallel_batches in map_and_batch.
+  options = tf.data.Options()
+  options.experimental_autotune = True
+  dataset = dataset.with_options(options)
+
+  # Prefetches a batch at a time to smooth out the time taken to load input
+  # files for shuffling and processing.
   dataset = dataset.prefetch(buffer_size=batch_size)
   if is_training:
-    # Shuffle the records. Note that we shuffle before repeating to ensure
-    # that the shuffling respects epoch boundaries.
+    # Shuffles records before repeating to respect epoch boundaries.
     dataset = dataset.shuffle(buffer_size=shuffle_buffer)
 
-  # If we are training over multiple epochs before evaluating, repeat the
-  # dataset for the appropriate number of epochs.
+  # Repeats the dataset for the number of epochs to train.
   dataset = dataset.repeat(num_epochs)
 
-  if is_training and num_gpus and examples_per_epoch:
-    total_examples = num_epochs * examples_per_epoch
-    # Force the number of batches to be divisible by the number of devices.
-    # This prevents some devices from receiving batches while others do not,
-    # which can lead to a lockup. This case will soon be handled directly by
-    # distribution strategies, at which point this .take() operation will no
-    # longer be needed.
-    total_batches = total_examples // batch_size // num_gpus * num_gpus
-    dataset.take(total_batches * batch_size)
-
-  # Parse the raw records into images and labels. Testing has shown that setting
-  # num_parallel_batches > 1 produces no improvement in throughput, since
-  # batch_size is almost always much greater than the number of CPU cores.
+  # Parses the raw records into images and labels.
   dataset = dataset.apply(
       tf.contrib.data.map_and_batch(
          lambda value: parse_record_fn(value, is_training, dtype),
          batch_size=batch_size,
-         num_parallel_batches=1,
          drop_remainder=False))
 
   # Operations between the final prefetch and the get_next call to the iterator
...
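The first change replaces the hand-tuned num_parallel_batches=1 with a dataset-level autotune option. Below is a minimal standalone sketch of the same toggle, using only the tf.data.Options / with_options calls that appear in the new code; the helper name enable_autotune and the toy dataset are illustrative, not part of the commit.

import tensorflow as tf

def enable_autotune(dataset):
  # Attach an Options object that lets tf.data tune parallelism knobs
  # (e.g. num_parallel_batches in map_and_batch) instead of hard-coding them.
  options = tf.data.Options()
  options.experimental_autotune = True
  return dataset.with_options(options)

# Toy usage: the option travels with the dataset through later transformations.
dataset = enable_autotune(tf.data.Dataset.range(100)).batch(8)

Note that experimental_autotune is the TF 1.x-era attribute used by this commit; newer TensorFlow releases expose the same switch as options.autotune.enabled.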
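The "noop take" in the commit title refers to the deleted dataset.take(total_batches * batch_size) line: tf.data transformations are functional and return a new Dataset, so a call whose result is never assigned does nothing, and the divisibility workaround it was meant to implement never took effect. A minimal sketch of the distinction (TF 2.x eager style for brevity; the toy dataset is illustrative only):

import tensorflow as tf

ds = tf.data.Dataset.range(10)

# Same shape as the removed line: the returned dataset is discarded, so the
# original pipeline is unchanged -- a no-op.
ds.take(4)
print(len(list(ds.as_numpy_iterator())))  # 10: all elements still present

# To actually truncate the dataset, the result must be reassigned.
ds = ds.take(4)
print(len(list(ds.as_numpy_iterator())))  # 4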