"docs/vscode:/vscode.git/clone" did not exist on "1344ebc8333df0bb5463a65b2d71f981659e071f"
Commit 6e52c271 authored by Neal Wu
Browse files

Separate parse_and_preprocess into two different dataset.map calls, which also keeps tests passing

parent 807d6bde
......@@ -108,23 +108,26 @@ def parse_record(raw_record):
# Convert bytes to a vector of uint8 that is record_bytes long.
record_vector = tf.decode_raw(raw_record, tf.uint8)
# The first byte represents the label, which we convert from uint8 to int32.
# The first byte represents the label, which we convert from uint8 to int32
# and then to one-hot.
label = tf.cast(record_vector[0], tf.int32)
label = tf.one_hot(label, _NUM_CLASSES)
# The remaining bytes after the label represent the image, which we reshape
# from [depth * height * width] to [depth, height, width].
depth_major = tf.reshape(record_vector[label_bytes:record_bytes],
[_DEPTH, _HEIGHT, _WIDTH])
depth_major = tf.reshape(
record_vector[label_bytes:record_bytes], [_DEPTH, _HEIGHT, _WIDTH])
# Convert from [depth, height, width] to [height, width, depth], and cast as
# float32.
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
return image, tf.one_hot(label, _NUM_CLASSES)
return image, label
def train_preprocess_fn(image):
"""Preprocess a single training image of layout [height, width, depth]."""
def preprocess_image(image, is_training):
"""Preprocess a single image of layout [height, width, depth]."""
if is_training:
# Resize the image to add four extra pixels on each side.
image = tf.image.resize_image_with_crop_or_pad(image, _HEIGHT + 8, _WIDTH + 8)
......@@ -134,19 +137,9 @@ def train_preprocess_fn(image):
# Randomly flip the image horizontally.
image = tf.image.random_flip_left_right(image)
return image
def parse_and_preprocess(record, is_training):
"""Parse and preprocess records in the CIFAR-10 dataset."""
image, label = parse_record(record)
if is_training:
image = train_preprocess_fn(image)
# Subtract off the mean and divide by the standard deviation of the pixels.
image = tf.image.per_image_standardization(image)
return image, label
return image
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
......@@ -168,8 +161,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
# randomness, while smaller sizes have better performance.
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
dataset = dataset.map(parse_record)
dataset = dataset.map(
lambda record: parse_and_preprocess(record, is_training))
lambda image, label: (preprocess_image(image, is_training), label))
dataset = dataset.prefetch(2 * batch_size)
# We call repeat after shuffling, rather than before, to prevent separate
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment