Commit 71e8adc7 authored by James Qin

Replace in-memory DataSet with FixedLengthRecordDataset

An in-memory DataSet initialized without a feed_dict is saved into the TF graph def,
which bloats the exported graph.pbtxt.
Switch to FixedLengthRecordDataset and do some related refactoring.
parent 3fb07dc0
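
For context, a minimal sketch of the issue this commit fixes (illustrative only; the array sizes and the filename are placeholders, not code from this repository): a Dataset built directly from in-memory numpy arrays is captured as constant tensors in the graph def, whereas a file-backed FixedLengthRecordDataset stores only filenames and a record length, so the exported graph.pbtxt stays small.

```python
import numpy as np
import tensorflow as tf

# Illustrative: in-memory data passed to from_tensor_slices (without a
# feed_dict/placeholder) is baked into the graph as constant tensors, so
# graph.pbtxt grows with the data size (~600 MB here).
images = np.zeros((50000, 32 * 32 * 3), dtype=np.float32)
labels = np.zeros((50000,), dtype=np.int32)
in_memory = tf.contrib.data.Dataset.from_tensor_slices((images, labels))

# A file-backed dataset only records filenames and the record length in the
# graph; the bytes are read at run time, keeping the exported graph def small.
record_bytes = 32 * 32 * 3 + 1  # 3072 image bytes + 1 label byte per record
file_backed = tf.contrib.data.FixedLengthRecordDataset(
    ['data_batch_1.bin'], record_bytes)  # placeholder filename
```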
@@ -11,8 +11,8 @@ Code in this directory focuses on how to use TensorFlow Estimators to train and
 2. Download the CIFAR-10 dataset.

 ```shell
-curl -o cifar-10-python.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
-tar xzf cifar-10-python.tar.gz
+curl -o cifar-10-binary.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
+tar xzf cifar-10-binary.tar.gz
 ```

 <b>How to run:</b>
@@ -20,26 +20,26 @@ tar xzf cifar-10-python.tar.gz
 ```shell
 # After running the above commands, you should see the following in the folder
 # where the data is downloaded.
-$ ls -R cifar-10-batches-py
-cifar-10-batches-py:
-batches.meta  data_batch_2  data_batch_4  readme.html
-data_batch_1  data_batch_3  data_batch_5  test_batch
+$ ls -R cifar-10-batches-bin
+cifar-10-batches-bin:
+batches.meta.txt  data_batch_1.bin  data_batch_2.bin  data_batch_3.bin
+data_batch_4.bin  data_batch_5.bin  readme.html  test_batch.bin

 # Run the model on CPU only. After training, it runs the evaluation.
-$ python cifar10_main.py --data_dir=/prefix/to/downloaded/data/cifar-10-batches-py \
-  --model_dir=/tmp/resnet_model \
+$ python cifar10_main.py --data_dir=/prefix/to/downloaded/data/cifar-10-batches-bin \
+  --model_dir=/tmp/cifar10 \
   --is_cpu_ps=True \
   --num_gpus=0 \
   --train_steps=1000

 # Run the model on CPU and 2 GPUs. After training, it runs the evaluation.
-$ python cifar10_main.py --data_dir=/prefix/to/downloaded/data/cifar-10-batches-py \
-  --model_dir=/tmp/resnet_model \
+$ python cifar10_main.py --data_dir=/prefix/to/downloaded/data/cifar-10-batches-bin \
+  --model_dir=/tmp/cifar10 \
   --is_cpu_ps=False \
   --force_gpu_compatible=True \
   --num_gpus=2 \
   --train_steps=1000

 # There are more command line flags to play with; check cifar10_main.py for details.
 ```
@@ -27,8 +27,6 @@ import tensorflow as tf
 HEIGHT = 32
 WIDTH = 32
 DEPTH = 3
-NUM_CLASSES = 10

 class Cifar10DataSet(object):
   """Cifar10 data set.
@@ -36,45 +34,93 @@ class Cifar10DataSet(object):
   Described by http://www.cs.toronto.edu/~kriz/cifar.html.
   """

-  def __init__(self, data_dir):
+  def __init__(self, data_dir, subset='train', use_distortion=True):
     self.data_dir = data_dir
+    self.subset = subset
+    self.use_distortion = use_distortion

-  def read_all_data(self, subset='train'):
-    """Reads from data file and return images and labels in a numpy array."""
-    if subset == 'train':
-      filenames = [
-          os.path.join(self.data_dir, 'data_batch_%d' % i)
-          for i in xrange(1, 5)
-      ]
-    elif subset == 'validation':
-      filenames = [os.path.join(self.data_dir, 'data_batch_5')]
-    elif subset == 'eval':
-      filenames = [os.path.join(self.data_dir, 'test_batch')]
-    else:
-      raise ValueError('Invalid data subset "%s"' % subset)
-
-    inputs = []
-    for filename in filenames:
-      with tf.gfile.Open(filename, 'r') as f:
-        inputs.append(cPickle.load(f))
-    all_images = np.concatenate([each_input['data']
-                                 for each_input in inputs]).astype(np.float32)
-    all_labels = np.concatenate([each_input['labels'] for each_input in inputs])
-    return all_images, all_labels
-
-  @staticmethod
-  def preprocess(image, is_training, distortion):
-    with tf.name_scope('preprocess'):
-      # Read image layout as flattened CHW.
-      image = tf.reshape(image, [DEPTH, HEIGHT, WIDTH])
-      # Convert to NHWC layout, compatible with TF image preprocessing APIs
-      image = tf.transpose(image, [1, 2, 0])
-      if is_training and distortion:
-        # Pad 4 pixels on each dimension of feature map, done in mini-batch
-        image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
-        image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
-        image = tf.image.random_flip_left_right(image)
-    return image
+  def get_filenames(self):
+    if self.subset == 'train':
+      return [
+          os.path.join(self.data_dir, 'data_batch_%d.bin' % i)
+          for i in xrange(1, 5)
+      ]
+    elif self.subset == 'validation':
+      return [os.path.join(self.data_dir, 'data_batch_5.bin')]
+    elif self.subset == 'eval':
+      return [os.path.join(self.data_dir, 'test_batch.bin')]
+    else:
+      raise ValueError('Invalid data subset "%s"' % self.subset)
+
+  def make_batch(self, batch_size):
+    """Read the images and labels from 'filenames'."""
+    filenames = self.get_filenames()
+    record_bytes = (32 * 32 * 3) + 1
+    # Repeat infinitely.
+    dataset = tf.contrib.data.FixedLengthRecordDataset(filenames,
+                                                       record_bytes).repeat()
+
+    # Parse records.
+    dataset = dataset.map(self.parser, num_threads=batch_size,
+                          output_buffer_size=2 * batch_size)
+
+    # Potentially shuffle records.
+    if self.subset == 'train':
+      min_queue_examples = int(
+          Cifar10DataSet.num_examples_per_epoch(self.subset) * 0.4)
+      # Ensure that the capacity is sufficiently large to provide good random
+      # shuffling.
+      dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)
+
+    # Batch it up.
+    dataset = dataset.batch(batch_size)
+    iterator = dataset.make_one_shot_iterator()
+    image_batch, label_batch = iterator.get_next()
+
+    return image_batch, label_batch
+
+  def parser(self, value):
+    """Parse a Cifar10 record from value.
+
+    Output images are in [height, width, depth] layout.
+    """
+    # Dimensions of the images in the CIFAR-10 dataset.
+    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
+    # input format.
+    label_bytes = 1
+    image_bytes = HEIGHT * WIDTH * DEPTH
+    # Every record consists of a label followed by the image, with a
+    # fixed number of bytes for each.
+    record_bytes = label_bytes + image_bytes
+
+    # Convert from a string to a vector of uint8 that is record_bytes long.
+    record_as_bytes = tf.decode_raw(value, tf.uint8)
+    # The first bytes represent the label, which we convert from
+    # uint8->int32.
+    label = tf.cast(
+        tf.strided_slice(record_as_bytes, [0], [label_bytes]), tf.int32)
+    label.set_shape([1])
+
+    # The remaining bytes after the label represent the image, which
+    # we reshape from [depth * height * width] to [depth, height, width].
+    depth_major = tf.reshape(
+        tf.strided_slice(record_as_bytes, [label_bytes], [record_bytes]),
+        [3, 32, 32])
+    # Convert from [depth, height, width] to [height, width, depth].
+    # This puts data in a compatible layout with TF image preprocessing APIs.
+    image = tf.transpose(depth_major, [1, 2, 0])
+
+    # Do custom preprocessing here.
+    image = self.preprocess(image)
+    return image, label
+
+  def preprocess(self, image):
+    """Preprocess a single image in [height, width, depth] layout."""
+    if self.subset == 'train' and self.use_distortion:
+      # Pad 4 pixels on each dimension of feature map, done in mini-batch
+      image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
+      image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
+      image = tf.image.random_flip_left_right(image)
+    return image
   @staticmethod
   def num_examples_per_epoch(subset='train'):
...
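
As background for the parser added above: each record in the CIFAR-10 .bin files is one label byte followed by 3072 image bytes (3 channels of 32x32 pixels, stored channel-major). The following standalone sketch decodes the first record with numpy; the file path is illustrative and the snippet is not part of this commit.

```python
import numpy as np

# Decode the first record of a CIFAR-10 binary batch file.
LABEL_BYTES = 1
IMAGE_BYTES = 3 * 32 * 32                  # depth * height * width
RECORD_BYTES = LABEL_BYTES + IMAGE_BYTES   # 3073 bytes per record

with open('cifar-10-batches-bin/data_batch_1.bin', 'rb') as f:  # illustrative path
    record = np.frombuffer(f.read(RECORD_BYTES), dtype=np.uint8)

label = int(record[0])                                     # class id in [0, 9]
image = record[1:].reshape(3, 32, 32).transpose(1, 2, 0)   # CHW -> HWC, like the parser
print(label, image.shape)                                  # e.g. 6 (32, 32, 3)
```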
@@ -302,43 +302,16 @@ def input_fn(subset, num_shards):
   Returns:
     two lists of tensors for features and labels, each of num_shards length.
   """
-  dataset = cifar10.Cifar10DataSet(FLAGS.data_dir)
-  is_training = (subset == 'train')
-  if is_training:
+  if subset == 'train':
     batch_size = FLAGS.train_batch_size
-  else:
+  elif subset == 'validate' or subset == 'eval':
     batch_size = FLAGS.eval_batch_size
-  with tf.device('/cpu:0'), tf.name_scope('batching'):
-    # CPU loads all data from disk since there're only 60k 32*32 RGB images.
-    all_images, all_labels = dataset.read_all_data(subset)
-    dataset = tf.contrib.data.Dataset.from_tensor_slices(
-        (all_images, all_labels))
-    dataset = dataset.map(
-        lambda x, y: (tf.cast(x, tf.float32), tf.cast(y, tf.int32)),
-        num_threads=2,
-        output_buffer_size=batch_size)
-
-    # Image preprocessing.
-    def _preprocess(image, label):
-      # If GPU is available, NHWC to NCHW transpose is done in ResNetCifar10
-      # class, not included in preprocessing.
-      return cifar10.Cifar10DataSet.preprocess(
-          image, is_training, FLAGS.use_distortion_for_training), label
-    dataset = dataset.map(
-        _preprocess, num_threads=batch_size, output_buffer_size=2 * batch_size)
-
-    # Repeat infinitely.
-    dataset = dataset.repeat()
-    if is_training:
-      min_fraction_of_examples_in_queue = 0.4
-      min_queue_examples = int(
-          cifar10.Cifar10DataSet.num_examples_per_epoch(subset) *
-          min_fraction_of_examples_in_queue)
-      # Ensure that the capacity is sufficiently large to provide good random
-      # shuffling
-      dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)
-
-    dataset = dataset.batch(batch_size)
-    iterator = dataset.make_one_shot_iterator()
-    image_batch, label_batch = iterator.get_next()
+  else:
+    raise ValueError('Subset must be one of \'train\', \'validate\' and \'eval\'')
+  with tf.device('/cpu:0'):
+    use_distortion = subset == 'train' and FLAGS.use_distortion_for_training
+    dataset = cifar10.Cifar10DataSet(FLAGS.data_dir, subset, use_distortion)
+    image_batch, label_batch = dataset.make_batch(batch_size)

   if num_shards <= 1:
     # No GPU available or only 1 GPU.
     return [image_batch], [label_batch]
...
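
Putting the pieces together, a minimal usage sketch of the refactored Cifar10DataSet (the data path, batch size, and session loop are illustrative and not part of this commit): the caller constructs the dataset per subset and pulls batches from the one-shot iterator, with no feed_dict involved.

```python
import tensorflow as tf
import cifar10  # the module modified above

data_dir = '/tmp/cifar-10-batches-bin'  # placeholder path
batch_size = 128                        # placeholder value

with tf.device('/cpu:0'):
  # Training subset with distortion enabled, mirroring input_fn above.
  dataset = cifar10.Cifar10DataSet(data_dir, subset='train', use_distortion=True)
  image_batch, label_batch = dataset.make_batch(batch_size)

with tf.Session() as sess:
  # No feed_dict is needed: the graph stores only filenames and record length,
  # so the exported graph.pbtxt stays small.
  images, labels = sess.run([image_batch, label_batch])
  print(images.shape, labels.shape)  # (128, 32, 32, 3), (128, 1)
```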