Unverified Commit 7b304676 authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Merge pull request #5253 from tfboyd/resnet_input_pipeline_fp16

Move tf.cast to tf.float16 in input pipeline 
parents 7babedc5 5c0c749b
......@@ -66,7 +66,7 @@ def get_filenames(is_training, data_dir):
return [os.path.join(data_dir, 'test_batch.bin')]
def parse_record(raw_record, is_training):
def parse_record(raw_record, is_training, dtype):
"""Parse CIFAR-10 image and label from a raw record."""
# Convert bytes to a vector of uint8 that is record_bytes long.
record_vector = tf.decode_raw(raw_record, tf.uint8)
......@@ -85,6 +85,7 @@ def parse_record(raw_record, is_training):
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
image = preprocess_image(image, is_training)
image = tf.cast(image, dtype)
return image, label
......@@ -107,8 +108,9 @@ def preprocess_image(image, is_training):
return image
def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
dtype=tf.float32):
"""Input function which provides batches for train or eval.
Args:
is_training: A boolean denoting whether the input is for training.
......@@ -116,6 +118,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training.
dtype: Data type to use for images/features
Returns:
A dataset that can be used for iteration.
......@@ -131,7 +134,8 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
parse_record_fn=parse_record,
num_epochs=num_epochs,
num_gpus=num_gpus,
examples_per_epoch=_NUM_IMAGES['train'] if is_training else None
examples_per_epoch=_NUM_IMAGES['train'] if is_training else None,
dtype=dtype
)
......
......@@ -61,7 +61,7 @@ class BaseTest(tf.test.TestCase):
fake_dataset = tf.data.FixedLengthRecordDataset(
filename, cifar10_main._RECORD_BYTES) # pylint: disable=protected-access
fake_dataset = fake_dataset.map(
lambda val: cifar10_main.parse_record(val, False))
lambda val: cifar10_main.parse_record(val, False, tf.float32))
image, label = fake_dataset.make_one_shot_iterator().get_next()
self.assertAllEqual(label.shape, ())
......
......@@ -129,7 +129,7 @@ def _parse_example_proto(example_serialized):
return features['image/encoded'], label, bbox
def parse_record(raw_record, is_training):
def parse_record(raw_record, is_training, dtype):
"""Parses a record containing a training example of an image.
The input record is parsed into a label and image, and the image is passed
......@@ -139,6 +139,7 @@ def parse_record(raw_record, is_training):
raw_record: scalar Tensor tf.string containing a serialized
Example protocol buffer.
is_training: A boolean denoting whether the input is for training.
dtype: data type to use for images/features.
Returns:
Tuple with processed image tensor and one-hot-encoded label tensor.
......@@ -152,11 +153,13 @@ def parse_record(raw_record, is_training):
output_width=_DEFAULT_IMAGE_SIZE,
num_channels=_NUM_CHANNELS,
is_training=is_training)
image = tf.cast(image, dtype)
return image, label
def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
dtype=tf.float32):
"""Input function which provides batches for train or eval.
Args:
......@@ -165,6 +168,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training.
dtype: Data type to use for images/features
Returns:
A dataset that can be used for iteration.
......@@ -192,7 +196,8 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
parse_record_fn=parse_record,
num_epochs=num_epochs,
num_gpus=num_gpus,
examples_per_epoch=_NUM_IMAGES['train'] if is_training else None
examples_per_epoch=_NUM_IMAGES['train'] if is_training else None,
dtype=dtype
)
......
......@@ -45,7 +45,7 @@ from official.utils.misc import model_helpers
################################################################################
def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
parse_record_fn, num_epochs=1, num_gpus=None,
examples_per_epoch=None):
examples_per_epoch=None, dtype=tf.float32):
"""Given a Dataset with raw records, return an iterator over the records.
Args:
......@@ -60,6 +60,7 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training.
examples_per_epoch: The number of examples in an epoch.
dtype: Data type to use for images/features.
Returns:
Dataset of (image, label) pairs ready for iteration.
......@@ -92,7 +93,7 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
# batch_size is almost always much greater than the number of CPU cores.
dataset = dataset.apply(
tf.contrib.data.map_and_batch(
lambda value: parse_record_fn(value, is_training),
lambda value: parse_record_fn(value, is_training, dtype),
batch_size=batch_size,
num_parallel_batches=1,
drop_remainder=False))
......@@ -248,8 +249,8 @@ def resnet_model_fn(features, labels, mode, model_class,
# Generate a summary node for the images
tf.summary.image('images', features, max_outputs=6)
# TODO(tobyboyd): Add cast as part of input pipeline on cpu and remove.
features = tf.cast(features, dtype)
# Checks that features/images have same data type being used for calculations.
assert features.dtype == dtype
model = model_class(resnet_size, data_format, resnet_version=resnet_version,
dtype=dtype)
......@@ -454,14 +455,16 @@ def resnet_main(
batch_size=distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
num_epochs=num_epochs,
num_gpus=flags_core.get_num_gpus(flags_obj))
num_gpus=flags_core.get_num_gpus(flags_obj),
dtype=flags_core.get_tf_dtype(flags_obj))
def input_fn_eval():
return input_function(
is_training=False, data_dir=flags_obj.data_dir,
batch_size=distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
num_epochs=1)
num_epochs=1,
dtype=flags_core.get_tf_dtype(flags_obj))
if flags_obj.eval_only or not flags_obj.train_epochs:
# If --eval_only is set, perform a single loop with zero train epochs.
......@@ -533,7 +536,7 @@ def define_resnet_flags(resnet_size_choices=None):
'If not None initialize all the network except the final layer with '
'these values'))
flags.DEFINE_boolean(
name="eval_only", default=False,
name='eval_only', default=False,
help=flags_core.help_wrap('Skip training and only perform evaluation on '
'the latest checkpoint.'))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment