Commit 6e52c271 authored by Neal Wu

Separate parse_and_preprocess into two different dataset.map calls, which also keeps tests passing

parent 807d6bde
...
@@ -108,45 +108,38 @@ def parse_record(raw_record):
   # Convert bytes to a vector of uint8 that is record_bytes long.
   record_vector = tf.decode_raw(raw_record, tf.uint8)
 
-  # The first byte represents the label, which we convert from uint8 to int32.
+  # The first byte represents the label, which we convert from uint8 to int32
+  # and then to one-hot.
   label = tf.cast(record_vector[0], tf.int32)
+  label = tf.one_hot(label, _NUM_CLASSES)
 
   # The remaining bytes after the label represent the image, which we reshape
   # from [depth * height * width] to [depth, height, width].
-  depth_major = tf.reshape(record_vector[label_bytes:record_bytes],
-                           [_DEPTH, _HEIGHT, _WIDTH])
+  depth_major = tf.reshape(
+      record_vector[label_bytes:record_bytes], [_DEPTH, _HEIGHT, _WIDTH])
 
   # Convert from [depth, height, width] to [height, width, depth], and cast as
   # float32.
   image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
 
-  return image, tf.one_hot(label, _NUM_CLASSES)
+  return image, label
 
 
-def train_preprocess_fn(image):
-  """Preprocess a single training image of layout [height, width, depth]."""
-  # Resize the image to add four extra pixels on each side.
-  image = tf.image.resize_image_with_crop_or_pad(image, _HEIGHT + 8, _WIDTH + 8)
-
-  # Randomly crop a [_HEIGHT, _WIDTH] section of the image.
-  image = tf.random_crop(image, [_HEIGHT, _WIDTH, _DEPTH])
-
-  # Randomly flip the image horizontally.
-  image = tf.image.random_flip_left_right(image)
-
-  return image
-
-
-def parse_and_preprocess(record, is_training):
-  """Parse and preprocess records in the CIFAR-10 dataset."""
-  image, label = parse_record(record)
-  if is_training:
-    image = train_preprocess_fn(image)
+def preprocess_image(image, is_training):
+  """Preprocess a single image of layout [height, width, depth]."""
+  if is_training:
+    # Resize the image to add four extra pixels on each side.
+    image = tf.image.resize_image_with_crop_or_pad(image, _HEIGHT + 8, _WIDTH + 8)
+
+    # Randomly crop a [_HEIGHT, _WIDTH] section of the image.
+    image = tf.random_crop(image, [_HEIGHT, _WIDTH, _DEPTH])
+
+    # Randomly flip the image horizontally.
+    image = tf.image.random_flip_left_right(image)
 
   # Subtract off the mean and divide by the variance of the pixels.
   image = tf.image.per_image_standardization(image)
-  return image, label
+  return image
 
 
 def input_fn(is_training, data_dir, batch_size, num_epochs=1):
...
@@ -168,8 +161,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   # randomness, while smaller sizes have better performance.
   dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
 
+  dataset = dataset.map(parse_record)
   dataset = dataset.map(
-      lambda record: parse_and_preprocess(record, is_training))
+      lambda image, label: (preprocess_image(image, is_training), label))
   dataset = dataset.prefetch(2 * batch_size)
 
   # We call repeat after shuffling, rather than before, to prevent separate
...
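
For reference, a minimal sketch of the pipeline shape this change produces. The helpers toy_parse and toy_preprocess below are hypothetical stand-ins for the repo's parse_record and preprocess_image, used only to show how the two chained dataset.map calls compose: the first map does parsing and yields (image, label) pairs, the second map preprocesses the image and threads the label through unchanged.

import tensorflow as tf

# Hypothetical stand-ins; the real parsing/augmentation logic lives in
# cifar10_main.py.
def toy_parse(raw):
  image = tf.zeros([32, 32, 3], tf.float32)             # pretend-decoded image
  label = tf.one_hot(tf.cast(raw, tf.int32) % 10, 10)   # pretend one-hot label
  return image, label

def toy_preprocess(image, is_training):
  if is_training:
    image = tf.image.random_flip_left_right(image)
  return tf.image.per_image_standardization(image)

is_training = True
dataset = tf.data.Dataset.range(8)
# First map: parsing only, producing (image, label) pairs.
dataset = dataset.map(toy_parse)
# Second map: preprocessing only; the label passes through untouched.
dataset = dataset.map(
    lambda image, label: (toy_preprocess(image, is_training), label))
dataset = dataset.batch(4)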