"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "b11e3b9be3246257e8206bf090bb2e6d980a7dd0"
Unverified commit 2ea91716, authored by Karmel Allison, committed by GitHub

Refactoring input function and allowing for more parallel processing (#3332)

* Refactoring input function

* Updating name of variable
parent cacb4c6d
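
The change centralizes the tf.data pipeline for both CIFAR-10 and ImageNet in a shared helper and threads a new num_parallel_calls argument through the input functions. A minimal sketch of how the refactored CIFAR-10 input function would be consumed (TF 1.x API; the data path is a placeholder and must already contain the extracted cifar-10-batches-bin files):

    import tensorflow as tf
    import cifar10_main

    # Build the dataset with record parsing parallelized across 8 calls.
    dataset = cifar10_main.input_fn(
        is_training=True, data_dir='/tmp/cifar10_data', batch_size=128,
        num_epochs=1, num_parallel_calls=8)

    # input_fn now returns a Dataset rather than (images, labels) tensors,
    # so callers make the iterator themselves (or let the Estimator do it).
    images, labels = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
      image_batch, label_batch = sess.run([images, labels])
      print(image_batch.shape, label_batch.shape)  # (128, 32, 32, 3) (128, 10)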
cifar10_main.py
@@ -29,6 +29,8 @@
 _HEIGHT = 32
 _WIDTH = 32
 _NUM_CHANNELS = 3
 _DEFAULT_IMAGE_BYTES = _HEIGHT * _WIDTH * _NUM_CHANNELS
+# The record is the image plus a one-byte label
+_RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1
 _NUM_CLASSES = 10
 _NUM_DATA_FILES = 5
@@ -41,12 +43,6 @@ _NUM_IMAGES = {
 ###############################################################################
 # Data processing
 ###############################################################################
-def record_dataset(filenames):
-  """Returns an input pipeline Dataset from `filenames`."""
-  record_bytes = _DEFAULT_IMAGE_BYTES + 1
-  return tf.data.FixedLengthRecordDataset(filenames, record_bytes)
-
-
 def get_filenames(is_training, data_dir):
   """Returns a list of filenames."""
   data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
@@ -64,13 +60,8 @@ def get_filenames(is_training, data_dir):
   return [os.path.join(data_dir, 'test_batch.bin')]


-def parse_record(raw_record):
+def parse_record(raw_record, is_training):
   """Parse CIFAR-10 image and label from a raw record."""
-  # Every record consists of a label followed by the image, with a fixed number
-  # of bytes for each.
-  label_bytes = 1
-  record_bytes = label_bytes + _DEFAULT_IMAGE_BYTES
-
   # Convert bytes to a vector of uint8 that is record_bytes long.
   record_vector = tf.decode_raw(raw_record, tf.uint8)
@@ -81,13 +72,15 @@ def parse_record(raw_record):
   # The remaining bytes after the label represent the image, which we reshape
   # from [depth * height * width] to [depth, height, width].
-  depth_major = tf.reshape(record_vector[label_bytes:record_bytes],
+  depth_major = tf.reshape(record_vector[1:_RECORD_BYTES],
                            [_NUM_CHANNELS, _HEIGHT, _WIDTH])

   # Convert from [depth, height, width] to [height, width, depth], and cast as
   # float32.
   image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

+  image = preprocess_image(image, is_training)
   return image, label
@@ -109,7 +102,8 @@ def preprocess_image(image, is_training):
   return image


-def input_fn(is_training, data_dir, batch_size, num_epochs=1):
+def input_fn(is_training, data_dir, batch_size, num_epochs=1,
+             num_parallel_calls=1):
   """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.

   Args:
@@ -117,35 +111,18 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
     data_dir: The directory containing the input data.
     batch_size: The number of samples per batch.
     num_epochs: The number of epochs to repeat the dataset.
+    num_parallel_calls: The number of records that are processed in parallel.
+      This can be optimized per data set but for generally homogeneous data
+      sets, should be approximately the number of available CPU cores.

   Returns:
-    A tuple of images and labels.
+    A dataset that can be used for iteration.
   """
-  dataset = record_dataset(get_filenames(is_training, data_dir))
-
-  if is_training:
-    # When choosing shuffle buffer sizes, larger sizes result in better
-    # randomness, while smaller sizes have better performance. Because CIFAR-10
-    # is a relatively small dataset, we choose to shuffle the full epoch.
-    dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])
-
-  dataset = dataset.map(parse_record)
-  dataset = dataset.map(
-      lambda image, label: (preprocess_image(image, is_training), label))
-
-  dataset = dataset.prefetch(2 * batch_size)
-
-  # We call repeat after shuffling, rather than before, to prevent separate
-  # epochs from blending together.
-  dataset = dataset.repeat(num_epochs)
-
-  # Batch results by up to batch_size, and then fetch the tuple from the
-  # iterator.
-  dataset = dataset.batch(batch_size)
-
-  iterator = dataset.make_one_shot_iterator()
-  images, labels = iterator.get_next()
-
-  return images, labels
+  filenames = get_filenames(is_training, data_dir)
+  dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
+
+  return resnet.process_record_dataset(dataset, is_training, batch_size,
+      _NUM_IMAGES['train'], parse_record, num_epochs, num_parallel_calls)


 ###############################################################################
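
For reference, each CIFAR-10 binary record is one label byte followed by 3 * 32 * 32 image bytes stored channel-major, which is why _RECORD_BYTES is _DEFAULT_IMAGE_BYTES + 1 and why parse_record slices record_vector[1:_RECORD_BYTES] before transposing to height-width-depth. A small NumPy sketch of the same slicing and transpose, using synthetic bytes rather than real data:

    import numpy as np

    HEIGHT, WIDTH, CHANNELS = 32, 32, 3
    IMAGE_BYTES = HEIGHT * WIDTH * CHANNELS      # 3072
    RECORD_BYTES = IMAGE_BYTES + 1               # 3073: label byte + image bytes

    record = (np.arange(RECORD_BYTES) % 256).astype(np.uint8)  # synthetic record
    label = int(record[0])                                     # first byte is the label
    depth_major = record[1:RECORD_BYTES].reshape(CHANNELS, HEIGHT, WIDTH)
    image = depth_major.transpose(1, 2, 0).astype(np.float32)  # [height, width, depth]
    print(label, image.shape)                                  # 0 (32, 32, 3)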
cifar10_test.py
@@ -43,8 +43,10 @@ class BaseTest(tf.test.TestCase):
     data_file.write(fake_data)
     data_file.close()

-    fake_dataset = cifar10_main.record_dataset(filename)
-    fake_dataset = fake_dataset.map(cifar10_main.parse_record)
+    fake_dataset = tf.data.FixedLengthRecordDataset(
+        filename, cifar10_main._RECORD_BYTES)
+    fake_dataset = fake_dataset.map(
+        lambda val: cifar10_main.parse_record(val, False))
     image, label = fake_dataset.make_one_shot_iterator().get_next()

     self.assertEqual(label.get_shape().as_list(), [10])
@@ -57,7 +59,7 @@ class BaseTest(tf.test.TestCase):
     for row in image:
       for pixel in row:
-        self.assertAllEqual(pixel, np.array([0, 1, 2]))
+        self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)

   def input_fn(self):
     features = tf.random_uniform([_BATCH_SIZE, 32, 32, 3])
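
The updated assertion follows from parse_record now applying preprocess_image: in eval mode the CIFAR-10 preprocessing ends with per-image standardization, so a fake image whose every pixel is [0, 1, 2] comes back as roughly [-1.225, 0., 1.225]. A quick check of that arithmetic (assuming standardization subtracts the image mean and divides by the image standard deviation, as tf.image.per_image_standardization does):

    import numpy as np

    # The whole fake image repeats the pixel [0, 1, 2], so the image mean is 1
    # and the (population) standard deviation is sqrt(2/3) ~= 0.8165.
    pixel = np.array([0.0, 1.0, 2.0])
    mean = pixel.mean()                              # 1.0
    stddev = np.sqrt(((pixel - mean) ** 2).mean())   # ~0.8165
    print((pixel - mean) / stddev)                   # ~[-1.2247  0.  1.2247]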
imagenet_main.py
@@ -35,19 +35,19 @@ _NUM_IMAGES = {
   'validation': 50000,
 }

-_FILE_SHUFFLE_BUFFER = 1024
+_NUM_TRAIN_FILES = 1024
 _SHUFFLE_BUFFER = 1500


 ###############################################################################
 # Data processing
 ###############################################################################
-def filenames(is_training, data_dir):
+def get_filenames(is_training, data_dir):
   """Return filenames for dataset."""
   if is_training:
     return [
         os.path.join(data_dir, 'train-%05d-of-01024' % i)
-        for i in range(1024)]
+        for i in range(_NUM_TRAIN_FILES)]
   else:
     return [
         os.path.join(data_dir, 'validation-%05d-of-00128' % i)
@@ -97,32 +97,33 @@ def parse_record(raw_record, is_training):
   return image, tf.one_hot(label, _NUM_CLASSES)


-def input_fn(is_training, data_dir, batch_size, num_epochs=1):
-  """Input function which provides batches for train or eval."""
-  dataset = tf.data.Dataset.from_tensor_slices(
-      filenames(is_training, data_dir))
+def input_fn(is_training, data_dir, batch_size, num_epochs=1,
+             num_parallel_calls=1):
+  """Input function which provides batches for train or eval.
+
+  Args:
+    is_training: A boolean denoting whether the input is for training.
+    data_dir: The directory containing the input data.
+    batch_size: The number of samples per batch.
+    num_epochs: The number of epochs to repeat the dataset.
+    num_parallel_calls: The number of records that are processed in parallel.
+      This can be optimized per data set but for generally homogeneous data
+      sets, should be approximately the number of available CPU cores.
+
+  Returns:
+    A dataset that can be used for iteration.
+  """
+  filenames = get_filenames(is_training, data_dir)
+  dataset = tf.data.Dataset.from_tensor_slices(filenames)

   if is_training:
-    dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)
+    # Shuffle the input files
+    dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)

+  # Convert to individual records
   dataset = dataset.flat_map(tf.data.TFRecordDataset)
-  dataset = dataset.map(lambda value: parse_record(value, is_training),
-                        num_parallel_calls=5)
-  dataset = dataset.prefetch(batch_size)
-
-  if is_training:
-    # When choosing shuffle buffer sizes, larger sizes result in better
-    # randomness, while smaller sizes have better performance.
-    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
-
-  # We call repeat after shuffling, rather than before, to prevent separate
-  # epochs from blending together.
-  dataset = dataset.repeat(num_epochs)
-  dataset = dataset.batch(batch_size)
-
-  iterator = dataset.make_one_shot_iterator()
-  images, labels = iterator.get_next()
-  return images, labels
+  return resnet.process_record_dataset(dataset, is_training, batch_size,
+      _SHUFFLE_BUFFER, parse_record, num_epochs, num_parallel_calls)


 ###############################################################################
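
The ImageNet pipeline now builds a dataset of shard filenames, shuffles at the file level during training, and flattens each shard into records with flat_map before handing off to the shared helper. A toy sketch of that file-to-record pattern using in-memory stand-ins instead of TFRecord shards (all names here are illustrative):

    import tensorflow as tf

    # Stand-ins for shard filenames; the real code lists files named like
    # 'train-%05d-of-01024' and reads them with tf.data.TFRecordDataset.
    shards = tf.data.Dataset.from_tensor_slices(['shard-0', 'shard-1', 'shard-2'])
    shards = shards.shuffle(buffer_size=3)  # file-level shuffle (training only)

    def records_in_shard(name):
      # Pretend each shard holds four records; TFRecordDataset plays this role.
      return tf.data.Dataset.from_tensors(name).repeat(4)

    # flat_map concatenates the per-shard datasets into one stream of records.
    records = shards.flat_map(records_in_shard)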
resnet.py
@@ -42,6 +42,57 @@ _BATCH_NORM_DECAY = 0.997
 _BATCH_NORM_EPSILON = 1e-5


+################################################################################
+# Functions for input processing.
+################################################################################
+def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
+                           parse_record_fn, num_epochs=1, num_parallel_calls=1):
+  """Given a Dataset with raw records, parse each record into images and labels,
+  and return a dataset ready for iteration.
+
+  Args:
+    dataset: A Dataset representing raw records.
+    is_training: A boolean denoting whether the input is for training.
+    batch_size: The number of samples per batch.
+    shuffle_buffer: The buffer size to use when shuffling records. A larger
+      value results in better randomness, but smaller values reduce startup
+      time and use less memory.
+    parse_record_fn: A function that takes a raw record and returns the
+      corresponding (image, label) pair.
+    num_epochs: The number of epochs to repeat the dataset.
+    num_parallel_calls: The number of records that are processed in parallel.
+      This can be optimized per data set but for generally homogeneous data
+      sets, should be approximately the number of available CPU cores.
+
+  Returns:
+    Dataset of (image, label) pairs ready for iteration.
+  """
+  # We prefetch a batch at a time. This can help smooth out the time taken to
+  # load input files as we go through shuffling and processing.
+  dataset = dataset.prefetch(buffer_size=batch_size)
+
+  if is_training:
+    # Shuffle the records. Note that we shuffle before repeating to ensure
+    # that the shuffling respects epoch boundaries.
+    dataset = dataset.shuffle(buffer_size=shuffle_buffer)
+
+  # If we are training over multiple epochs before evaluating, repeat the
+  # dataset for the appropriate number of epochs.
+  dataset = dataset.repeat(num_epochs)
+
+  # Parse the raw records into images and labels.
+  dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
+                        num_parallel_calls=num_parallel_calls)
+
+  dataset = dataset.batch(batch_size)
+
+  # Operations between the final prefetch and the get_next call to the iterator
+  # will happen synchronously during run time. We prefetch here again to
+  # background all of the above processing work and keep it out of the
+  # critical training path.
+  dataset = dataset.prefetch(1)
+
+  return dataset
+
+
 ################################################################################
 # Functions building the ResNet model.
 ################################################################################
@@ -494,15 +545,20 @@ def resnet_main(flags, model_function, input_function):
         tensors=tensors_to_log, every_n_iter=100)

     print('Starting a training cycle.')
-    classifier.train(
-        input_fn=lambda: input_function(
-            True, flags.data_dir, flags.batch_size, flags.epochs_per_eval),
-        hooks=[logging_hook])
+
+    def input_fn_train():
+      return input_function(True, flags.data_dir, flags.batch_size,
+                            flags.epochs_per_eval, flags.num_parallel_calls)
+
+    classifier.train(input_fn=input_fn_train, hooks=[logging_hook])

     print('Starting to evaluate.')

     # Evaluate the model and print results
-    eval_results = classifier.evaluate(input_fn=lambda: input_function(
-        False, flags.data_dir, flags.batch_size))
+    def input_fn_eval():
+      return input_function(False, flags.data_dir, flags.batch_size,
+                            1, flags.num_parallel_calls)
+
+    eval_results = classifier.evaluate(input_fn=input_fn_eval)
     print(eval_results)
@@ -516,6 +572,13 @@ class ResnetArgParser(argparse.ArgumentParser):
         '--data_dir', type=str, default='/tmp/resnet_data',
         help='The directory where the input data is stored.')

+    self.add_argument(
+        '--num_parallel_calls', type=int, default=5,
+        help='The number of records that are processed in parallel '
+             'during input processing. This can be optimized per data set but '
+             'for generally homogeneous data sets, should be approximately the '
+             'number of available CPU cores.')
+
     self.add_argument(
         '--model_dir', type=str, default='/tmp/resnet_model',
         help='The directory where the model will be stored.')
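
Because process_record_dataset only needs a Dataset of raw records and a parse function with the (value, is_training) signature, other datasets can reuse it unchanged. A minimal sketch with a synthetic dataset and a trivial parser, purely illustrative and not part of this commit:

    import tensorflow as tf
    import resnet  # the module that defines process_record_dataset above

    def parse_fn(value, is_training):
      # Trivial parser: treat the raw record as the feature and label it 0.
      return tf.cast(value, tf.float32), tf.constant(0)

    raw_records = tf.data.Dataset.range(100)  # stand-in for raw records
    dataset = resnet.process_record_dataset(
        raw_records, is_training=True, batch_size=10, shuffle_buffer=100,
        parse_record_fn=parse_fn, num_epochs=2, num_parallel_calls=4)

    # Batches of 10 (feature, label) pairs, shuffled and repeated for 2 epochs.
    features, labels = dataset.make_one_shot_iterator().get_next()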