"tools/vscode:/vscode.git/clone" did not exist on "9a336696104834d836812037f43489b0d36f51ea"
Commit 0f2bf200 authored by Marianne Linhares Monteiro, committed by GitHub

Using TFRecords instead of TextLineDataset

parent 6f6bc501
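
This change assumes the pickled CIFAR-10 batches have already been converted into .tfrecords files whose Examples carry a raw-uint8 "image" bytes feature (flattened depth-major, matching the reshape in parser() below) and an int64 "label" feature. A minimal converter sketch under that assumption; the helper name convert_to_tfrecord is illustrative and not part of this commit:

import cPickle

import tensorflow as tf

def convert_to_tfrecord(input_file, output_file):
  """Hypothetical helper: one pickled CIFAR-10 batch -> one .tfrecords file."""
  with tf.gfile.Open(input_file, 'rb') as f:
    batch = cPickle.load(f)
  with tf.python_io.TFRecordWriter(output_file) as writer:
    for image, label in zip(batch['data'], batch['labels']):
      # Each 'data' row is a uint8 vector of length 3 * 32 * 32 (3072) stored
      # depth-major, which is exactly the layout parser() decodes below.
      example = tf.train.Example(features=tf.train.Features(feature={
          "image": tf.train.Feature(
              bytes_list=tf.train.BytesList(value=[image.tobytes()])),
          "label": tf.train.Feature(
              int64_list=tf.train.Int64List(value=[int(label)])),
      }))
      writer.write(example.SerializeToString())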
@@ -27,8 +27,6 @@ import tensorflow as tf
 HEIGHT = 32
 WIDTH = 32
 DEPTH = 3
-NUM_CLASSES = 10
-
 
 class Cifar10DataSet(object):
   """Cifar10 data set.
@@ -36,45 +34,81 @@ class Cifar10DataSet(object):
   Described by http://www.cs.toronto.edu/~kriz/cifar.html.
   """
 
-  def __init__(self, data_dir):
+  def __init__(self, data_dir, subset='train', use_distortion=True):
     self.data_dir = data_dir
+    self.subset = subset
+    self.use_distortion = use_distortion
 
-  def read_all_data(self, subset='train'):
-    """Reads from data file and returns images and labels in a numpy array."""
-    if subset == 'train':
-      filenames = [
-          os.path.join(self.data_dir, 'data_batch_%d' % i)
-          for i in xrange(1, 5)
-      ]
-    elif subset == 'validation':
-      filenames = [os.path.join(self.data_dir, 'data_batch_5')]
-    elif subset == 'eval':
-      filenames = [os.path.join(self.data_dir, 'test_batch')]
-    else:
-      raise ValueError('Invalid data subset "%s"' % subset)
-
-    inputs = []
-    for filename in filenames:
-      with tf.gfile.Open(filename, 'r') as f:
-        inputs.append(cPickle.load(f))
-    all_images = np.concatenate([each_input['data']
-                                 for each_input in inputs]).astype(np.float32)
-    all_labels = np.concatenate([each_input['labels'] for each_input in inputs])
-    return all_images, all_labels
-
-  @staticmethod
-  def preprocess(image, is_training, distortion):
-    with tf.name_scope('preprocess'):
-      # Read image layout as flattened CHW.
-      image = tf.reshape(image, [DEPTH, HEIGHT, WIDTH])
-      # Convert to NHWC layout, compatible with TF image preprocessing APIs.
-      image = tf.transpose(image, [1, 2, 0])
-      if is_training and distortion:
-        # Pad 4 pixels on each dimension of feature map, done in mini-batch.
-        image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
-        image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
-        image = tf.image.random_flip_left_right(image)
-    return image
+  def get_filenames(self):
+    if self.subset == 'train':
+      return [
+          os.path.join(self.data_dir, 'data_batch_%d.tfrecords' % i)
+          for i in xrange(1, 5)
+      ]
+    elif self.subset == 'validation':
+      return [os.path.join(self.data_dir, 'data_batch_5.tfrecords')]
+    elif self.subset == 'eval':
+      return [os.path.join(self.data_dir, 'test_batch.tfrecords')]
+    else:
+      raise ValueError('Invalid data subset "%s"' % self.subset)
+
+  def parser(self, serialized_example):
+    """Parses a single tf.Example into image and label tensors."""
+    # Dimensions of the images in the CIFAR-10 dataset.
+    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
+    # input format.
+    features = tf.parse_single_example(
+        serialized_example,
+        features={
+            "image": tf.FixedLenFeature([], tf.string),
+            "label": tf.FixedLenFeature([], tf.int64),
+        })
+    image = tf.decode_raw(features["image"], tf.uint8)
+    image.set_shape([DEPTH * HEIGHT * WIDTH])
+
+    # Reshape from [depth * height * width] to [height, width, depth].
+    image = tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0])
+    label = tf.cast(features["label"], tf.int32)
+
+    # Custom preprocessing.
+    image = self.preprocess(image)
+
+    return image, label
+
+  def make_batch(self, batch_size):
+    """Reads the images and labels from the subset's TFRecord files."""
+    filenames = self.get_filenames()
+    # Repeat infinitely.
+    dataset = tf.contrib.data.TFRecordDataset(filenames).repeat()
+
+    # Parse records.
+    dataset = dataset.map(self.parser, num_threads=batch_size,
+                          output_buffer_size=2 * batch_size)
+
+    # Potentially shuffle records.
+    if self.subset == 'train':
+      min_queue_examples = int(
+          Cifar10DataSet.num_examples_per_epoch(self.subset) * 0.4)
+      # Ensure that the capacity is sufficiently large to provide good random
+      # shuffling.
+      dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)
+
+    # Batch it up.
+    dataset = dataset.batch(batch_size)
+    iterator = dataset.make_one_shot_iterator()
+    image_batch, label_batch = iterator.get_next()
+
+    return image_batch, label_batch
+
+  def preprocess(self, image):
+    """Preprocesses a single image in [height, width, depth] layout."""
+    if self.subset == 'train' and self.use_distortion:
+      # Pad 4 pixels on each dimension of feature map, done in mini-batch.
+      image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
+      image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
+      image = tf.image.random_flip_left_right(image)
+    return image
 
   @staticmethod
   def num_examples_per_epoch(subset='train'):
...
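
For context, make_batch() is the entry point an Estimator input function would call; the dataset repeats infinitely, so the number of training steps is controlled by the caller. A minimal usage sketch, assuming the TF 1.2-era Estimator API (train_input_fn is illustrative, not part of this commit):

def train_input_fn(data_dir, batch_size=128):
  """Hypothetical input_fn wiring Cifar10DataSet into an Estimator."""
  dataset = Cifar10DataSet(data_dir, subset='train', use_distortion=True)
  image_batch, label_batch = dataset.make_batch(batch_size)
  return image_batch, label_batch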