Unverified Commit 5f0776a2 authored by Neal Wu's avatar Neal Wu Committed by GitHub
Browse files

Move dataset.map back to before dataset.shuffle in imagenet_main.py (#2731)

parent 21b48a85
...@@ -89,8 +89,8 @@ def filenames(is_training, data_dir): ...@@ -89,8 +89,8 @@ def filenames(is_training, data_dir):
for i in range(128)] for i in range(128)]
def dataset_parser(value, is_training): def record_parser(value, is_training):
"""Parse an Imagenet record from value.""" """Parse an ImageNet record from `value`."""
keys_to_features = { keys_to_features = {
'image/encoded': 'image/encoded':
tf.FixedLenFeature((), tf.string, default_value=''), tf.FixedLenFeature((), tf.string, default_value=''),
...@@ -134,23 +134,21 @@ def dataset_parser(value, is_training): ...@@ -134,23 +134,21 @@ def dataset_parser(value, is_training):
def input_fn(is_training, data_dir, batch_size, num_epochs=1): def input_fn(is_training, data_dir, batch_size, num_epochs=1):
"""Input function which provides batches for train or eval.""" """Input function which provides batches for train or eval."""
dataset = tf.data.Dataset.from_tensor_slices( dataset = tf.data.Dataset.from_tensor_slices(filenames(is_training, data_dir))
filenames(is_training, data_dir))
if is_training: if is_training:
dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER) dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)
dataset = dataset.flat_map(tf.data.TFRecordDataset) dataset = dataset.flat_map(tf.data.TFRecordDataset)
dataset = dataset.map(lambda value: record_parser(value, is_training),
num_parallel_calls=5)
dataset = dataset.prefetch(batch_size)
if is_training: if is_training:
# When choosing shuffle buffer sizes, larger sizes result in better # When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes have better performance. # randomness, while smaller sizes have better performance.
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER) dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
dataset = dataset.map(lambda value: dataset_parser(value, is_training),
num_parallel_calls=5)
dataset = dataset.prefetch(batch_size)
# We call repeat after shuffling, rather than before, to prevent separate # We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together. # epochs from blending together.
dataset = dataset.repeat(num_epochs) dataset = dataset.repeat(num_epochs)
......
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.