Commit 807d6bde authored by Kathy Wu's avatar Kathy Wu
Browse files

Fixed cifar 10 tests

parent e5f88ad6
...@@ -97,8 +97,8 @@ def get_filenames(is_training, data_dir): ...@@ -97,8 +97,8 @@ def get_filenames(is_training, data_dir):
return [os.path.join(data_dir, 'test_batch.bin')] return [os.path.join(data_dir, 'test_batch.bin')]
def parse_and_preprocess_record(raw_record, is_training): def parse_record(raw_record):
"""Parse and preprocess a CIFAR-10 image and label from a raw record.""" """Parse CIFAR-10 image and label from a raw record."""
# Every record consists of a label followed by the image, with a fixed number # Every record consists of a label followed by the image, with a fixed number
# of bytes for each. # of bytes for each.
label_bytes = 1 label_bytes = 1
...@@ -120,12 +120,6 @@ def parse_and_preprocess_record(raw_record, is_training): ...@@ -120,12 +120,6 @@ def parse_and_preprocess_record(raw_record, is_training):
# float32. # float32.
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32) image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
if is_training:
image = train_preprocess_fn(image)
# Subtract off the mean and divide by the variance of the pixels.
image = tf.image.per_image_standardization(image)
return image, tf.one_hot(label, _NUM_CLASSES) return image, tf.one_hot(label, _NUM_CLASSES)
...@@ -143,6 +137,18 @@ def train_preprocess_fn(image): ...@@ -143,6 +137,18 @@ def train_preprocess_fn(image):
return image return image
def parse_and_preprocess(record, is_training):
"""Parse and preprocess records in the CIFAR-10 dataset."""
image, label = parse_record(record)
if is_training:
image = train_preprocess_fn(image)
# Subtract off the mean and divide by the variance of the pixels.
image = tf.image.per_image_standardization(image)
return image, label
def input_fn(is_training, data_dir, batch_size, num_epochs=1): def input_fn(is_training, data_dir, batch_size, num_epochs=1):
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset. """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
...@@ -163,7 +169,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1): ...@@ -163,7 +169,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER) dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
dataset = dataset.map( dataset = dataset.map(
lambda record: parse_and_preprocess_record(record, is_training)) lambda record: parse_and_preprocess(record, is_training))
dataset = dataset.prefetch(2 * batch_size) dataset = dataset.prefetch(2 * batch_size)
# We call repeat after shuffling, rather than before, to prevent separate # We call repeat after shuffling, rather than before, to prevent separate
......
...@@ -44,7 +44,7 @@ class BaseTest(tf.test.TestCase): ...@@ -44,7 +44,7 @@ class BaseTest(tf.test.TestCase):
data_file.close() data_file.close()
fake_dataset = cifar10_main.record_dataset(filename) fake_dataset = cifar10_main.record_dataset(filename)
fake_dataset = fake_dataset.map(cifar10_main.dataset_parser) fake_dataset = fake_dataset.map(cifar10_main.parse_record)
image, label = fake_dataset.make_one_shot_iterator().get_next() image, label = fake_dataset.make_one_shot_iterator().get_next()
self.assertEqual(label.get_shape().as_list(), [10]) self.assertEqual(label.get_shape().as_list(), [10])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment