"vscode:/vscode.git/clone" did not exist on "86aeec1005cd0c7c00cbfeb2d99cc9ef0a63f9df"
Unverified commit 2ea91716, authored by Karmel Allison and committed by GitHub

Refactoring input function and allowing for more parallel processing (#3332)

* Refactoring input function

* Updating name of variable
parent cacb4c6d
@@ -29,6 +29,8 @@ _HEIGHT = 32
_WIDTH = 32
_NUM_CHANNELS = 3
_DEFAULT_IMAGE_BYTES = _HEIGHT * _WIDTH * _NUM_CHANNELS
# The record is the image plus a one-byte label
_RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1
_NUM_CLASSES = 10
_NUM_DATA_FILES = 5
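For context, each CIFAR-10 binary record is one label byte followed by a 32x32x3 image stored depth-major, so the new module-level constant works out to 3073 bytes per record. A quick sanity check of that arithmetic (a sketch, not part of the diff):

```python
_HEIGHT, _WIDTH, _NUM_CHANNELS = 32, 32, 3
_DEFAULT_IMAGE_BYTES = _HEIGHT * _WIDTH * _NUM_CHANNELS  # 3072 pixel bytes
_RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1                 # plus one label byte
assert _RECORD_BYTES == 3073
```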
@@ -41,12 +43,6 @@ _NUM_IMAGES = {
###############################################################################
# Data processing
###############################################################################
def record_dataset(filenames):
"""Returns an input pipeline Dataset from `filenames`."""
record_bytes = _DEFAULT_IMAGE_BYTES + 1
return tf.data.FixedLengthRecordDataset(filenames, record_bytes)
def get_filenames(is_training, data_dir):
"""Returns a list of filenames."""
data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
@@ -64,13 +60,8 @@ def get_filenames(is_training, data_dir):
return [os.path.join(data_dir, 'test_batch.bin')]
def parse_record(raw_record):
def parse_record(raw_record, is_training):
"""Parse CIFAR-10 image and label from a raw record."""
# Every record consists of a label followed by the image, with a fixed number
# of bytes for each.
label_bytes = 1
record_bytes = label_bytes + _DEFAULT_IMAGE_BYTES
# Convert bytes to a vector of uint8 that is record_bytes long.
record_vector = tf.decode_raw(raw_record, tf.uint8)
@@ -81,13 +72,15 @@ def parse_record(raw_record):
# The remaining bytes after the label represent the image, which we reshape
# from [depth * height * width] to [depth, height, width].
depth_major = tf.reshape(record_vector[label_bytes:record_bytes],
depth_major = tf.reshape(record_vector[1:_RECORD_BYTES],
[_NUM_CHANNELS, _HEIGHT, _WIDTH])
# Convert from [depth, height, width] to [height, width, depth], and cast as
# float32.
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
image = preprocess_image(image, is_training)
return image, label
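The reshape and transpose above undo CIFAR-10's depth-major byte layout (all of channel 0, then channel 1, then channel 2) and produce the usual height-width-channel image. A tiny NumPy sketch of the same index shuffling, using toy sizes rather than the real constants:

```python
import numpy as np

# Toy record: a 2x2 image with 3 channels, stored depth-major after the label byte.
record_vector = np.arange(1 + 3 * 2 * 2, dtype=np.uint8)  # byte 0 is the label
depth_major = record_vector[1:].reshape(3, 2, 2)          # [channels, height, width]
image = depth_major.transpose(1, 2, 0)                    # [height, width, channels]
assert image[0, 0].tolist() == [1, 5, 9]  # pixel (0, 0): one byte from each plane
```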
@@ -109,7 +102,8 @@ def preprocess_image(image, is_training):
return image
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
def input_fn(is_training, data_dir, batch_size, num_epochs=1,
num_parallel_calls=1):
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
Args:
@@ -117,35 +111,18 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
data_dir: The directory containing the input data.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
Returns:
A tuple of images and labels.
A dataset that can be used for iteration.
"""
dataset = record_dataset(get_filenames(is_training, data_dir))
if is_training:
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes have better performance. Because CIFAR-10
# is a relatively small dataset, we choose to shuffle the full epoch.
dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])
dataset = dataset.map(parse_record)
dataset = dataset.map(
lambda image, label: (preprocess_image(image, is_training), label))
dataset = dataset.prefetch(2 * batch_size)
# We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
# Batch results by up to batch_size, and then fetch the tuple from the
# iterator.
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()
filenames = get_filenames(is_training, data_dir)
dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
return images, labels
return resnet.process_record_dataset(dataset, is_training, batch_size,
_NUM_IMAGES['train'], parse_record, num_epochs, num_parallel_calls)
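Note the return-type change: input_fn now hands back the tf.data.Dataset itself instead of (images, labels) tensors. Recent TF 1.x Estimator versions accept a Dataset-returning input_fn and build the iterator internally, which is why the explicit make_one_shot_iterator()/get_next() above could be dropped. Code that still wants tensors can iterate explicitly, roughly like this (the data path is illustrative and assumes the CIFAR-10 binaries are already present):

```python
import tensorflow as tf

dataset = input_fn(is_training=False, data_dir='/tmp/cifar10_data',
                   batch_size=4)                        # a tf.data.Dataset
images, labels = dataset.make_one_shot_iterator().get_next()
# images: float32 [4, 32, 32, 3]; labels: one-hot [4, 10] (see the test below)
```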
###############################################################################
......
@@ -43,8 +43,10 @@ class BaseTest(tf.test.TestCase):
data_file.write(fake_data)
data_file.close()
fake_dataset = cifar10_main.record_dataset(filename)
fake_dataset = fake_dataset.map(cifar10_main.parse_record)
fake_dataset = tf.data.FixedLengthRecordDataset(
filename, cifar10_main._RECORD_BYTES)
fake_dataset = fake_dataset.map(
lambda val: cifar10_main.parse_record(val, False))
image, label = fake_dataset.make_one_shot_iterator().get_next()
self.assertEqual(label.get_shape().as_list(), [10])
@@ -57,7 +59,7 @@
for row in image:
for pixel in row:
self.assertAllEqual(pixel, np.array([0, 1, 2]))
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
def input_fn(self):
features = tf.random_uniform([_BATCH_SIZE, 32, 32, 3])
......
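The expected pixel values in the test change because parse_record now calls preprocess_image, which standardizes the image even at eval time. Assuming preprocess_image ends with tf.image.per_image_standardization (consistent with the new expected values), every fake pixel is [0, 1, 2], so the per-image mean is 1 and the standard deviation is sqrt(2/3) ≈ 0.8165, which is where [-1.225, 0., 1.225] comes from:

```python
import numpy as np

pixel = np.array([0., 1., 2.])
# Every pixel of the fake image is identical, so the per-image mean and stddev
# equal the per-pixel ones; the adjusted-stddev floor (1/sqrt(3072)) is smaller
# and does not apply.
standardized = (pixel - pixel.mean()) / pixel.std()
print(standardized)  # approximately [-1.2247, 0., 1.2247]
```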
@@ -35,19 +35,19 @@ _NUM_IMAGES = {
'validation': 50000,
}
_FILE_SHUFFLE_BUFFER = 1024
_NUM_TRAIN_FILES = 1024
_SHUFFLE_BUFFER = 1500
###############################################################################
# Data processing
###############################################################################
def filenames(is_training, data_dir):
def get_filenames(is_training, data_dir):
"""Return filenames for dataset."""
if is_training:
return [
os.path.join(data_dir, 'train-%05d-of-01024' % i)
for i in range(1024)]
for i in range(_NUM_TRAIN_FILES)]
else:
return [
os.path.join(data_dir, 'validation-%05d-of-00128' % i)
@@ -97,32 +97,33 @@ def parse_record(raw_record, is_training):
return image, tf.one_hot(label, _NUM_CLASSES)
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
"""Input function which provides batches for train or eval."""
dataset = tf.data.Dataset.from_tensor_slices(
filenames(is_training, data_dir))
def input_fn(is_training, data_dir, batch_size, num_epochs=1,
num_parallel_calls=1):
"""Input function which provides batches for train or eval.
Args:
is_training: A boolean denoting whether the input is for training.
data_dir: The directory containing the input data.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
Returns:
A dataset that can be used for iteration.
"""
filenames = get_filenames(is_training, data_dir)
dataset = tf.data.Dataset.from_tensor_slices(filenames)
if is_training:
dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)
# Shuffle the input files
dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)
# Convert to individual records
dataset = dataset.flat_map(tf.data.TFRecordDataset)
dataset = dataset.map(lambda value: parse_record(value, is_training),
num_parallel_calls=5)
dataset = dataset.prefetch(batch_size)
if is_training:
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes have better performance.
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
# We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()
return images, labels
return resnet.process_record_dataset(dataset, is_training, batch_size,
_SHUFFLE_BUFFER, parse_record, num_epochs, num_parallel_calls)
###############################################################################
......
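For ImageNet, shuffling now happens at two levels: the 1024 training file names are shuffled once here, and process_record_dataset later shuffles individual records with a 1500-element buffer. A compact sketch of the file-to-record expansion that flat_map performs (file names are illustrative):

```python
import tensorflow as tf

filenames = ['train-00000-of-01024', 'train-00001-of-01024']  # illustrative
dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.shuffle(buffer_size=len(filenames))  # shuffle whole files first
dataset = dataset.flat_map(tf.data.TFRecordDataset)    # then emit their records
```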
@@ -42,6 +42,57 @@ _BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5
################################################################################
# Functions for input processing.
################################################################################
def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
parse_record_fn, num_epochs=1, num_parallel_calls=1):
"""Given a Dataset with raw records, parse each record into images and labels,
and return an iterator over the records.
Args:
dataset: A Dataset representing raw records
is_training: A boolean denoting whether the input is for training.
batch_size: The number of samples per batch.
shuffle_buffer: The buffer size to use when shuffling records. A larger
value results in better randomness, but smaller values reduce startup
time and use less memory.
parse_record_fn: A function that takes a raw record and returns the
corresponding (image, label) pair.
num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
Returns:
Dataset of (image, label) pairs ready for iteration.
"""
# We prefetch a batch at a time. This can help smooth out the time taken to
# load input files as we go through shuffling and processing.
dataset = dataset.prefetch(buffer_size=batch_size)
if is_training:
# Shuffle the records. Note that we shuffle before repeating to ensure
# that the shuffling respects epoch boundaries.
dataset = dataset.shuffle(buffer_size=shuffle_buffer)
# If we are training over multiple epochs before evaluating, repeat the
# dataset for the appropriate number of epochs.
dataset = dataset.repeat(num_epochs)
# Parse the raw records into images and labels
dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
num_parallel_calls=num_parallel_calls)
dataset = dataset.batch(batch_size)
# Operations between the final prefetch and the get_next call to the iterator
# will happen synchronously during run time. We prefetch here again to
# background all of the above processing work and keep it out of the
# critical training path.
dataset = dataset.prefetch(1)
return dataset
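The only contract process_record_dataset imposes is that parse_record_fn maps one raw record plus the is_training flag to an (image, label) pair. A minimal sketch of that contract with an in-memory dataset and a stand-in parser (dummy_parse_record and the toy data are hypothetical, not part of the change):

```python
import tensorflow as tf

def dummy_parse_record(raw_record, is_training):
  # Stand-in parser: treat the raw value itself as a one-element "image".
  image = tf.cast(tf.reshape(raw_record, [1]), tf.float32)
  label = tf.one_hot(0, 10)
  return image, label

raw = tf.data.Dataset.from_tensor_slices(tf.range(100))
dataset = process_record_dataset(
    raw, is_training=True, batch_size=8, shuffle_buffer=100,
    parse_record_fn=dummy_parse_record, num_epochs=2, num_parallel_calls=4)
images, labels = dataset.make_one_shot_iterator().get_next()
# images: float32 [8, 1]; labels: float32 [8, 10]
```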
################################################################################
# Functions building the ResNet model.
################################################################################
@@ -494,15 +545,20 @@ def resnet_main(flags, model_function, input_function):
tensors=tensors_to_log, every_n_iter=100)
print('Starting a training cycle.')
classifier.train(
input_fn=lambda: input_function(
True, flags.data_dir, flags.batch_size, flags.epochs_per_eval),
hooks=[logging_hook])
def input_fn_train():
return input_function(True, flags.data_dir, flags.batch_size,
flags.epochs_per_eval, flags.num_parallel_calls)
classifier.train(input_fn=input_fn_train, hooks=[logging_hook])
print('Starting to evaluate.')
# Evaluate the model and print results
eval_results = classifier.evaluate(input_fn=lambda: input_function(
False, flags.data_dir, flags.batch_size))
def input_fn_eval():
return input_function(False, flags.data_dir, flags.batch_size,
1, flags.num_parallel_calls)
eval_results = classifier.evaluate(input_fn=input_fn_eval)
print(eval_results)
@@ -516,6 +572,13 @@ class ResnetArgParser(argparse.ArgumentParser):
'--data_dir', type=str, default='/tmp/resnet_data',
help='The directory where the input data is stored.')
self.add_argument(
'--num_parallel_calls', type=int, default=5,
help='The number of records that are processed in parallel '
'during input processing. This can be optimized per data set but '
'for generally homogeneous data sets, should be approximately the '
'number of available CPU cores.')
self.add_argument(
'--model_dir', type=str, default='/tmp/resnet_model',
help='The directory where the model will be stored.')
......
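With the flag exposed here and threaded through resnet_main above, input parallelism can be tuned from the command line; an illustrative invocation (the script name and values are assumptions, not part of this diff):

```
python cifar10_main.py --data_dir /tmp/cifar10_data --num_parallel_calls 8
```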