Commit e5f88ad6 authored by Kathy Wu

Adding prefetch to dataset map functions, and combined map functions in cifar10_main

parent a97f5df7
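The same tf.data idiom recurs in every hunk below: chain prefetch onto the parsing map so parsed records are buffered ahead of the consumer. A minimal sketch of the pattern, with hypothetical parse_fn and filenames (TF 1.x API, as used in these files):

    import tensorflow as tf

    def make_batched_iterator(filenames, parse_fn, batch_size):
      """Hypothetical helper showing the map -> prefetch -> batch idiom."""
      dataset = tf.data.TFRecordDataset(filenames)
      # prefetch overlaps parsing with downstream work by keeping up to
      # batch_size parsed records buffered at any time.
      dataset = dataset.map(parse_fn).prefetch(batch_size)
      dataset = dataset.batch(batch_size)
      return dataset.make_one_shot_iterator().get_next()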
@@ -73,9 +73,6 @@ def input_fn(is_training, filename, batch_size=1, num_epochs=1):
   dataset = tf.data.TFRecordDataset([filename])

-  # Parse each example in the dataset
-  dataset = dataset.map(example_parser)
-
   # Apply dataset transformations
   if is_training:
     # When choosing shuffle buffer sizes, larger sizes result in better
@@ -88,8 +85,7 @@ def input_fn(is_training, filename, batch_size=1, num_epochs=1):
   dataset = dataset.repeat(num_epochs)

   # Map example_parser over dataset, and batch results by up to batch_size
-  dataset = dataset.map(
-      example_parser, num_threads=1, output_buffer_size=batch_size)
+  dataset = dataset.map(example_parser).prefetch(batch_size)
   dataset = dataset.batch(batch_size)

   iterator = dataset.make_one_shot_iterator()
   images, labels = iterator.get_next()
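With both hunks above applied, the training input function in this file reduces to one chained pipeline. A sketch of the resulting body, assuming example_parser as defined in the file and a placeholder shuffle buffer size:

    def input_fn(is_training, filename, batch_size=1, num_epochs=1):
      """Sketch of the post-commit pipeline; buffer size is an assumption."""
      dataset = tf.data.TFRecordDataset([filename])

      if is_training:
        dataset = dataset.shuffle(buffer_size=10000)  # assumed value

      dataset = dataset.repeat(num_epochs)
      # Parse each example, keeping up to batch_size parsed records buffered.
      dataset = dataset.map(example_parser).prefetch(batch_size)
      dataset = dataset.batch(batch_size)

      iterator = dataset.make_one_shot_iterator()
      return iterator.get_next()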
@@ -97,33 +97,39 @@ def get_filenames(is_training, data_dir):
     return [os.path.join(data_dir, 'test_batch.bin')]


-def dataset_parser(value):
-  """Parse a CIFAR-10 record from value."""
+def parse_and_preprocess_record(raw_record, is_training):
+  """Parse and preprocess a CIFAR-10 image and label from a raw record."""
   # Every record consists of a label followed by the image, with a fixed number
   # of bytes for each.
   label_bytes = 1
   image_bytes = _HEIGHT * _WIDTH * _DEPTH
   record_bytes = label_bytes + image_bytes

-  # Convert from a string to a vector of uint8 that is record_bytes long.
-  raw_record = tf.decode_raw(value, tf.uint8)
+  # Convert bytes to a vector of uint8 that is record_bytes long.
+  record_vector = tf.decode_raw(raw_record, tf.uint8)

   # The first byte represents the label, which we convert from uint8 to int32.
-  label = tf.cast(raw_record[0], tf.int32)
+  label = tf.cast(record_vector[0], tf.int32)

   # The remaining bytes after the label represent the image, which we reshape
   # from [depth * height * width] to [depth, height, width].
-  depth_major = tf.reshape(raw_record[label_bytes:record_bytes],
+  depth_major = tf.reshape(record_vector[label_bytes:record_bytes],
                            [_DEPTH, _HEIGHT, _WIDTH])

   # Convert from [depth, height, width] to [height, width, depth], and cast as
   # float32.
   image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

+  if is_training:
+    image = train_preprocess_fn(image)
+
+  # Subtract off the mean and divide by the variance of the pixels.
+  image = tf.image.per_image_standardization(image)
+
   return image, tf.one_hot(label, _NUM_CLASSES)


-def train_preprocess_fn(image, label):
+def train_preprocess_fn(image):
   """Preprocess a single training image of layout [height, width, depth]."""
   # Resize the image to add four extra pixels on each side.
   image = tf.image.resize_image_with_crop_or_pad(image, _HEIGHT + 8, _WIDTH + 8)
@@ -134,7 +140,7 @@ def train_preprocess_fn(image, label):
   # Randomly flip the image horizontally.
   image = tf.image.random_flip_left_right(image)

-  return image, label
+  return image


 def input_fn(is_training, data_dir, batch_size, num_epochs=1):
@@ -143,26 +149,22 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   Args:
     is_training: A boolean denoting whether the input is for training.
     data_dir: The directory containing the input data.
-    batch_size: The number samples per batch.
+    batch_size: The number of samples per batch.
     num_epochs: The number of epochs to repeat the dataset.

   Returns:
     A tuple of images and labels.
   """
   dataset = record_dataset(get_filenames(is_training, data_dir))
-  dataset = dataset.map(dataset_parser)

-  # For training, preprocess the image and shuffle.
   if is_training:
-    dataset = dataset.map(train_preprocess_fn)
-
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes have better performance.
     dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)

-  # Subtract off the mean and divide by the variance of the pixels.
   dataset = dataset.map(
-      lambda image, label: (tf.image.per_image_standardization(image), label))
+      lambda record: parse_and_preprocess_record(record, is_training))
+  dataset = dataset.prefetch(2 * batch_size)

   # We call repeat after shuffling, rather than before, to prevent separate
   # epochs from blending together.
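One detail in the hunk above: because prefetch(2 * batch_size) is applied before .batch(), it buffers individual parsed examples, not whole batches. A minimal sketch of the distinction (hypothetical parse_fn):

    # Element-level prefetch: up to 2 * batch_size parsed examples are
    # buffered ahead of batching (the form used in this commit).
    dataset = dataset.map(parse_fn).prefetch(2 * batch_size)
    dataset = dataset.batch(batch_size)

    # Batch-level prefetch would buffer whole batches instead:
    #   dataset = dataset.batch(batch_size).prefetch(2)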
@@ -143,7 +143,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   dataset = dataset.flat_map(tf.data.TFRecordDataset)

   dataset = dataset.map(lambda value: dataset_parser(value, is_training),
-                        num_parallel_calls=5)
+                        num_parallel_calls=5).prefetch(batch_size)

   if is_training:
     # When choosing shuffle buffer sizes, larger sizes result in better
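The last hunk combines parallel parsing with the same prefetch chaining: num_parallel_calls parses several records concurrently while prefetch keeps the results buffered. A hypothetical helper showing the idiom in isolation:

    def parallel_parse(dataset, parse_fn, is_training, batch_size,
                       num_parallel_calls=5):
      """Hypothetical helper: parse records concurrently, then buffer up to
      batch_size parsed examples ahead of the consumer."""
      return dataset.map(
          lambda value: parse_fn(value, is_training),
          num_parallel_calls=num_parallel_calls).prefetch(batch_size)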