Commit c6479e77 authored by Neal Wu

Ensure that shuffle occurs before map

parent 6e52c271
@@ -71,8 +71,6 @@ _NUM_IMAGES = {
     'validation': 10000,
 }
 
-_SHUFFLE_BUFFER = 20000
-
 
 def record_dataset(filenames):
   """Returns an input pipeline Dataset from `filenames`."""
@@ -158,8 +156,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   if is_training:
     # When choosing shuffle buffer sizes, larger sizes result in better
-    # randomness, while smaller sizes have better performance.
-    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
+    # randomness, while smaller sizes have better performance. Because CIFAR-10
+    # is a relatively small dataset, we choose to shuffle the full epoch.
+    dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])
 
   dataset = dataset.map(parse_record)
   dataset = dataset.map(
...
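
Note on the CIFAR-10 change: with shuffle moved ahead of map, the shuffle buffer holds small serialized records rather than parsed float tensors, which is what makes a full-epoch buffer affordable. A rough back-of-the-envelope sketch of the memory tradeoff (the byte sizes are assumptions based on the standard CIFAR-10 binary format, not figures from this commit):

# Rough memory estimate (assumed sizes, not from the commit): shuffling raw
# CIFAR-10 records is far cheaper than shuffling parsed float32 tensors.
RECORD_BYTES = 1 + 32 * 32 * 3   # serialized record: 1 label byte + raw pixels
PARSED_BYTES = 32 * 32 * 3 * 4   # same image parsed to float32
BUFFER_SIZE = 50000              # _NUM_IMAGES['train']: shuffle the full epoch

raw_mb = BUFFER_SIZE * RECORD_BYTES / 2**20
parsed_mb = BUFFER_SIZE * PARSED_BYTES / 2**20
print(f'shuffle before map: ~{raw_mb:.0f} MiB buffered')     # ~147 MiB
print(f'shuffle after map:  ~{parsed_mb:.0f} MiB buffered')  # ~586 MiB

At roughly a quarter of the memory, buffering the whole training set also yields a perfect per-epoch shuffle rather than an approximate one.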
@@ -142,14 +142,15 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   dataset = dataset.flat_map(tf.data.TFRecordDataset)
 
-  dataset = dataset.map(lambda value: dataset_parser(value, is_training),
-                        num_parallel_calls=5).prefetch(batch_size)
-
   if is_training:
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes have better performance.
     dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
 
+  dataset = dataset.map(lambda value: dataset_parser(value, is_training),
+                        num_parallel_calls=5)
+  dataset = dataset.prefetch(batch_size)
+
   # We call repeat after shuffling, rather than before, to prevent separate
   # epochs from blending together.
   dataset = dataset.repeat(num_epochs)
...
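
This second file makes the same reordering explicit: map and prefetch now run after the training-time shuffle, so parsing cost is paid only as elements leave the buffer. A minimal self-contained sketch of the resulting pipeline order, with a stand-in parser (parse_record_stub and input_fn_sketch are hypothetical names for illustration, not from the repository):

import tensorflow as tf

def parse_record_stub(value, is_training):
  # Stand-in for the real dataset_parser: in the actual model this would
  # decode the serialized record and apply training-time augmentation.
  return value

def input_fn_sketch(filenames, is_training, batch_size, num_epochs=1,
                    shuffle_buffer=20000):
  """Sketch of the pipeline order this commit establishes."""
  dataset = tf.data.Dataset.from_tensor_slices(filenames)
  dataset = dataset.flat_map(tf.data.TFRecordDataset)

  if is_training:
    # Shuffle first, while each element is still a small serialized string.
    dataset = dataset.shuffle(buffer_size=shuffle_buffer)

  # Parse only as elements are drawn from the shuffle buffer.
  dataset = dataset.map(lambda value: parse_record_stub(value, is_training),
                        num_parallel_calls=5)
  dataset = dataset.prefetch(batch_size)

  # Repeat after shuffle so each epoch is a complete, separate pass.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  return dataset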
@@ -179,11 +179,12 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
   # Extract lines from input files using the Dataset API.
   dataset = tf.data.TextLineDataset(data_file)
 
-  dataset = dataset.map(parse_csv, num_parallel_calls=5)
-
   if shuffle:
     dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
 
+  dataset = dataset.map(parse_csv, num_parallel_calls=5)
+
   # We call repeat after shuffling, rather than before, to prevent separate
   # epochs from blending together.
   dataset = dataset.repeat(num_epochs)
@@ -193,6 +194,7 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
   features, labels = iterator.get_next()
   return features, labels
 
 
 def main(unused_argv):
   # Clean up the model directory if present
   shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
...
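
The in-code comment about calling repeat after shuffling can be seen directly in a toy example. The sketch below (using the TF 2.x eager API for brevity, so it does not match the TF 1.x style of the diff; outputs are random and shown only as examples) contrasts the two orderings:

import tensorflow as tf

# Toy demonstration of why repeat comes after shuffle: shuffling after repeat
# can pull elements of epoch 2 ahead of the last elements of epoch 1.
ds = tf.data.Dataset.range(4)

shuffle_then_repeat = ds.shuffle(buffer_size=4).repeat(2)
repeat_then_shuffle = ds.repeat(2).shuffle(buffer_size=4)

print(list(shuffle_then_repeat.as_numpy_iterator()))
# e.g. [2, 0, 3, 1, 1, 3, 0, 2] -- each epoch is a complete permutation
print(list(repeat_then_shuffle.as_numpy_iterator()))
# e.g. [1, 0, 1, 3, 2, 0, 2, 3] -- elements of both epochs interleave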