Commit e5f88ad6 authored by Kathy Wu's avatar Kathy Wu
Browse files

Add prefetch to dataset map functions, and combine map functions in cifar10_main

parent a97f5df7
...@@ -73,9 +73,6 @@ def input_fn(is_training, filename, batch_size=1, num_epochs=1): ...@@ -73,9 +73,6 @@ def input_fn(is_training, filename, batch_size=1, num_epochs=1):
dataset = tf.data.TFRecordDataset([filename]) dataset = tf.data.TFRecordDataset([filename])
# Parse each example in the dataset
dataset = dataset.map(example_parser)
# Apply dataset transformations # Apply dataset transformations
if is_training: if is_training:
# When choosing shuffle buffer sizes, larger sizes result in better # When choosing shuffle buffer sizes, larger sizes result in better
...@@ -88,8 +85,7 @@ def input_fn(is_training, filename, batch_size=1, num_epochs=1): ...@@ -88,8 +85,7 @@ def input_fn(is_training, filename, batch_size=1, num_epochs=1):
dataset = dataset.repeat(num_epochs) dataset = dataset.repeat(num_epochs)
# Map example_parser over dataset, and batch results by up to batch_size # Map example_parser over dataset, and batch results by up to batch_size
dataset = dataset.map( dataset = dataset.map(example_parser).prefetch(batch_size)
example_parser, num_threads=1, output_buffer_size=batch_size)
dataset = dataset.batch(batch_size) dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator() iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next() images, labels = iterator.get_next()
......
...@@ -97,33 +97,39 @@ def get_filenames(is_training, data_dir): ...@@ -97,33 +97,39 @@ def get_filenames(is_training, data_dir):
return [os.path.join(data_dir, 'test_batch.bin')] return [os.path.join(data_dir, 'test_batch.bin')]
def dataset_parser(value): def parse_and_preprocess_record(raw_record, is_training):
"""Parse a CIFAR-10 record from value.""" """Parse and preprocess a CIFAR-10 image and label from a raw record."""
# Every record consists of a label followed by the image, with a fixed number # Every record consists of a label followed by the image, with a fixed number
# of bytes for each. # of bytes for each.
label_bytes = 1 label_bytes = 1
image_bytes = _HEIGHT * _WIDTH * _DEPTH image_bytes = _HEIGHT * _WIDTH * _DEPTH
record_bytes = label_bytes + image_bytes record_bytes = label_bytes + image_bytes
# Convert from a string to a vector of uint8 that is record_bytes long. # Convert bytes to a vector of uint8 that is record_bytes long.
raw_record = tf.decode_raw(value, tf.uint8) record_vector = tf.decode_raw(raw_record, tf.uint8)
# The first byte represents the label, which we convert from uint8 to int32. # The first byte represents the label, which we convert from uint8 to int32.
label = tf.cast(raw_record[0], tf.int32) label = tf.cast(record_vector[0], tf.int32)
# The remaining bytes after the label represent the image, which we reshape # The remaining bytes after the label represent the image, which we reshape
# from [depth * height * width] to [depth, height, width]. # from [depth * height * width] to [depth, height, width].
depth_major = tf.reshape(raw_record[label_bytes:record_bytes], depth_major = tf.reshape(record_vector[label_bytes:record_bytes],
[_DEPTH, _HEIGHT, _WIDTH]) [_DEPTH, _HEIGHT, _WIDTH])
# Convert from [depth, height, width] to [height, width, depth], and cast as # Convert from [depth, height, width] to [height, width, depth], and cast as
# float32. # float32.
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32) image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
if is_training:
image = train_preprocess_fn(image)
# Subtract off the mean and divide by the variance of the pixels.
image = tf.image.per_image_standardization(image)
return image, tf.one_hot(label, _NUM_CLASSES) return image, tf.one_hot(label, _NUM_CLASSES)
def train_preprocess_fn(image, label): def train_preprocess_fn(image):
"""Preprocess a single training image of layout [height, width, depth].""" """Preprocess a single training image of layout [height, width, depth]."""
# Resize the image to add four extra pixels on each side. # Resize the image to add four extra pixels on each side.
image = tf.image.resize_image_with_crop_or_pad(image, _HEIGHT + 8, _WIDTH + 8) image = tf.image.resize_image_with_crop_or_pad(image, _HEIGHT + 8, _WIDTH + 8)
...@@ -134,7 +140,7 @@ def train_preprocess_fn(image, label): ...@@ -134,7 +140,7 @@ def train_preprocess_fn(image, label):
# Randomly flip the image horizontally. # Randomly flip the image horizontally.
image = tf.image.random_flip_left_right(image) image = tf.image.random_flip_left_right(image)
return image, label return image
def input_fn(is_training, data_dir, batch_size, num_epochs=1): def input_fn(is_training, data_dir, batch_size, num_epochs=1):
...@@ -143,26 +149,22 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1): ...@@ -143,26 +149,22 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
Args: Args:
is_training: A boolean denoting whether the input is for training. is_training: A boolean denoting whether the input is for training.
data_dir: The directory containing the input data. data_dir: The directory containing the input data.
batch_size: The number samples per batch. batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
Returns: Returns:
A tuple of images and labels. A tuple of images and labels.
""" """
dataset = record_dataset(get_filenames(is_training, data_dir)) dataset = record_dataset(get_filenames(is_training, data_dir))
dataset = dataset.map(dataset_parser)
# For training, preprocess the image and shuffle.
if is_training: if is_training:
dataset = dataset.map(train_preprocess_fn)
# When choosing shuffle buffer sizes, larger sizes result in better # When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes have better performance. # randomness, while smaller sizes have better performance.
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER) dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
# Subtract off the mean and divide by the variance of the pixels.
dataset = dataset.map( dataset = dataset.map(
lambda image, label: (tf.image.per_image_standardization(image), label)) lambda record: parse_and_preprocess_record(record, is_training))
dataset = dataset.prefetch(2 * batch_size)
# We call repeat after shuffling, rather than before, to prevent separate # We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together. # epochs from blending together.
......
...@@ -143,7 +143,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1): ...@@ -143,7 +143,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
dataset = dataset.flat_map(tf.data.TFRecordDataset) dataset = dataset.flat_map(tf.data.TFRecordDataset)
dataset = dataset.map(lambda value: dataset_parser(value, is_training), dataset = dataset.map(lambda value: dataset_parser(value, is_training),
num_parallel_calls=5) num_parallel_calls=5).prefetch(batch_size)
if is_training: if is_training:
# When choosing shuffle buffer sizes, larger sizes result in better # When choosing shuffle buffer sizes, larger sizes result in better
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment