Commit 0f2bf200 authored by Marianne Linhares Monteiro, committed by GitHub

Using TFRecords instead of TextLineDataset

parent 6f6bc501
@@ -27,8 +27,6 @@ import tensorflow as tf
 HEIGHT = 32
 WIDTH = 32
 DEPTH = 3
-NUM_CLASSES = 10
-
 
 class Cifar10DataSet(object):
   """Cifar10 data set.
@@ -36,40 +34,76 @@ class Cifar10DataSet(object):
   Described by http://www.cs.toronto.edu/~kriz/cifar.html.
   """
 
-  def __init__(self, data_dir):
+  def __init__(self, data_dir, subset='train', use_distortion=True):
     self.data_dir = data_dir
+    self.subset = subset
+    self.use_distortion = use_distortion
 
-  def read_all_data(self, subset='train'):
-    """Reads from data file and return images and labels in a numpy array."""
-    if subset == 'train':
-      filenames = [
-          os.path.join(self.data_dir, 'data_batch_%d' % i)
-          for i in xrange(1, 5)
-      ]
-    elif subset == 'validation':
-      filenames = [os.path.join(self.data_dir, 'data_batch_5')]
-    elif subset == 'eval':
-      filenames = [os.path.join(self.data_dir, 'test_batch')]
-    else:
-      raise ValueError('Invalid data subset "%s"' % subset)
-
-    inputs = []
-    for filename in filenames:
-      with tf.gfile.Open(filename, 'r') as f:
-        inputs.append(cPickle.load(f))
-    all_images = np.concatenate([each_input['data']
-                                 for each_input in inputs]).astype(np.float32)
-    all_labels = np.concatenate([each_input['labels'] for each_input in inputs])
-    return all_images, all_labels
-
-  @staticmethod
-  def preprocess(image, is_training, distortion):
-    with tf.name_scope('preprocess'):
-      # Read image layout as flattened CHW.
-      image = tf.reshape(image, [DEPTH, HEIGHT, WIDTH])
-      # Convert to NHWC layout, compatible with TF image preprocessing APIs
-      image = tf.transpose(image, [1, 2, 0])
-      if is_training and distortion:
-        # Pad 4 pixels on each dimension of feature map, done in mini-batch
-        image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
-        image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
+  def get_filenames(self):
+    if self.subset == 'train':
+      return [
+          os.path.join(self.data_dir, 'data_batch_%d.tfrecords' % i)
+          for i in xrange(1, 5)
+      ]
+    elif self.subset == 'validation':
+      return [os.path.join(self.data_dir, 'data_batch_5.tfrecords')]
+    elif self.subset == 'eval':
+      return [os.path.join(self.data_dir, 'test_batch.tfrecords')]
+    else:
+      raise ValueError('Invalid data subset "%s"' % self.subset)
+
+  def parser(self, serialized_example):
+    """Parses a single tf.Example into image and label tensors."""
+    # Dimensions of the images in the CIFAR-10 dataset.
+    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
+    # input format.
+    features = tf.parse_single_example(
+        serialized_example,
+        features={
+            "image": tf.FixedLenFeature([], tf.string),
+            "label": tf.FixedLenFeature([], tf.int64),
+        })
+    image = tf.decode_raw(features["image"], tf.uint8)
+    image.set_shape([3*32*32])
+
+    # Reshape from [depth * height * width] to [height, width, depth].
+    image = tf.transpose(tf.reshape(image, [3, 32, 32]), [1, 2, 0])
+    label = tf.cast(features["label"], tf.int32)
+
+    # Custom preprocessing.
+    image = self.preprocess(image)
+
+    return image, label
+
+  def make_batch(self, batch_size):
+    """Read the images and labels from 'filenames'."""
+    filenames = self.get_filenames()
+
+    # Repeat infinitely.
+    dataset = tf.contrib.data.TFRecordDataset(filenames).repeat()
+
+    # Parse records.
+    dataset = dataset.map(self.parser, num_threads=batch_size,
+                          output_buffer_size=2 * batch_size)
+
+    # Potentially shuffle records.
+    if self.subset == 'train':
+      min_queue_examples = int(
+          Cifar10DataSet.num_examples_per_epoch(self.subset) * 0.4)
+      # Ensure that the capacity is sufficiently large to provide good random
+      # shuffling.
+      dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)
+
+    # Batch it up.
+    dataset = dataset.batch(batch_size)
+    iterator = dataset.make_one_shot_iterator()
+    image_batch, label_batch = iterator.get_next()
+
+    return image_batch, label_batch
+
+  def preprocess(self, image):
+    """Preprocess a single image in [height, width, depth] layout."""
+    if self.subset == 'train' and self.use_distortion:
+      # Pad 4 pixels on each dimension of feature map, done in mini-batch
+      image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
+      image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
...
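For context on the record format this commit introduces: parser() expects each .tfrecords file to hold tf.train.Example protos with a bytes feature "image" (the flattened depth * height * width uint8 pixels) and an int64 feature "label". Below is a minimal converter sketch, not part of this commit; the function name, the __main__ harness, and the data_dir path are assumptions, and it presumes the original pickled CIFAR-10 python batches are already on disk.

# Hypothetical converter, not part of this commit: writes the pickled
# CIFAR-10 batches as the .tfrecords files get_filenames() expects.
import cPickle
import os

import tensorflow as tf


def convert_to_tfrecords(input_file, output_file):
  """Converts one pickled CIFAR-10 batch into a .tfrecords file."""
  with tf.gfile.Open(input_file, 'r') as f:
    batch = cPickle.load(f)
  with tf.python_io.TFRecordWriter(output_file) as writer:
    for data, label in zip(batch['data'], batch['labels']):
      # Each 'data' row is already the flattened [depth * height * width]
      # uint8 layout that parser() decodes with tf.decode_raw.
      example = tf.train.Example(features=tf.train.Features(feature={
          'image': tf.train.Feature(
              bytes_list=tf.train.BytesList(value=[data.tobytes()])),
          'label': tf.train.Feature(
              int64_list=tf.train.Int64List(value=[label])),
      }))
      writer.write(example.SerializeToString())


if __name__ == '__main__':
  data_dir = '/tmp/cifar-10-batches-py'  # assumed extract location
  for name in ['data_batch_%d' % i for i in range(1, 6)] + ['test_batch']:
    convert_to_tfrecords(os.path.join(data_dir, name),
                         os.path.join(data_dir, name + '.tfrecords'))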
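make_batch() also calls Cifar10DataSet.num_examples_per_epoch, which falls outside the visible hunk. A plausible definition of that helper, assuming the four-training-batch / one-validation-batch split encoded in get_filenames() and the standard 10,000 images per CIFAR-10 batch file; these counts are an inference, not shown in this diff.

  # Plausible sketch of the helper referenced by make_batch(); belongs
  # inside Cifar10DataSet. Counts assume 10,000 images per batch file.
  @staticmethod
  def num_examples_per_epoch(subset='train'):
    if subset == 'train':
      return 4 * 10000   # data_batch_1.tfrecords .. data_batch_4.tfrecords
    elif subset == 'validation':
      return 10000       # data_batch_5.tfrecords
    elif subset == 'eval':
      return 10000       # test_batch.tfrecords
    else:
      raise ValueError('Invalid data subset "%s"' % subset)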
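Finally, a hedged usage sketch of the new pipeline: an Estimator input function would presumably just hand make_batch's tensors onward. The module name cifar10, the input_fn signature, and the CPU device pinning are illustrative choices, not part of this commit.

# Hypothetical usage, assuming the class above lives in a module named
# cifar10 and is consumed from an Estimator-style input function.
import tensorflow as tf

import cifar10


def input_fn(data_dir, subset='train', batch_size=128):
  # Build the input pipeline on the CPU so it does not compete with the
  # model for GPU resources; distort images only for the training subset.
  with tf.device('/cpu:0'):
    dataset = cifar10.Cifar10DataSet(data_dir, subset=subset,
                                     use_distortion=(subset == 'train'))
    image_batch, label_batch = dataset.make_batch(batch_size)
  return image_batch, label_batch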