Commit 76dbcb5a authored by Toby Boyd's avatar Toby Boyd
Browse files

Move tf.cast for fp16 to input pipeline.

parent 5856878d
...@@ -66,7 +66,7 @@ def get_filenames(is_training, data_dir): ...@@ -66,7 +66,7 @@ def get_filenames(is_training, data_dir):
return [os.path.join(data_dir, 'test_batch.bin')] return [os.path.join(data_dir, 'test_batch.bin')]
def parse_record(raw_record, is_training): def parse_record(raw_record, is_training, dtype):
"""Parse CIFAR-10 image and label from a raw record.""" """Parse CIFAR-10 image and label from a raw record."""
# Convert bytes to a vector of uint8 that is record_bytes long. # Convert bytes to a vector of uint8 that is record_bytes long.
record_vector = tf.decode_raw(raw_record, tf.uint8) record_vector = tf.decode_raw(raw_record, tf.uint8)
...@@ -85,6 +85,7 @@ def parse_record(raw_record, is_training): ...@@ -85,6 +85,7 @@ def parse_record(raw_record, is_training):
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32) image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
image = preprocess_image(image, is_training) image = preprocess_image(image, is_training)
image = tf.cast(image, dtype)
return image, label return image, label
...@@ -107,8 +108,9 @@ def preprocess_image(image, is_training): ...@@ -107,8 +108,9 @@ def preprocess_image(image, is_training):
return image return image
def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None): def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset. dtype=tf.float32):
"""Input function which provides batches for train or eval.
Args: Args:
is_training: A boolean denoting whether the input is for training. is_training: A boolean denoting whether the input is for training.
...@@ -116,6 +118,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None): ...@@ -116,6 +118,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
batch_size: The number of samples per batch. batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training. num_gpus: The number of gpus used for training.
dtype: Data type to use for images/features
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
...@@ -131,7 +134,8 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None): ...@@ -131,7 +134,8 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
parse_record_fn=parse_record, parse_record_fn=parse_record,
num_epochs=num_epochs, num_epochs=num_epochs,
num_gpus=num_gpus, num_gpus=num_gpus,
examples_per_epoch=_NUM_IMAGES['train'] if is_training else None examples_per_epoch=_NUM_IMAGES['train'] if is_training else None,
dtype=dtype
) )
......
...@@ -129,7 +129,7 @@ def _parse_example_proto(example_serialized): ...@@ -129,7 +129,7 @@ def _parse_example_proto(example_serialized):
return features['image/encoded'], label, bbox return features['image/encoded'], label, bbox
def parse_record(raw_record, is_training): def parse_record(raw_record, is_training, dtype):
"""Parses a record containing a training example of an image. """Parses a record containing a training example of an image.
The input record is parsed into a label and image, and the image is passed The input record is parsed into a label and image, and the image is passed
...@@ -139,6 +139,7 @@ def parse_record(raw_record, is_training): ...@@ -139,6 +139,7 @@ def parse_record(raw_record, is_training):
raw_record: scalar Tensor tf.string containing a serialized raw_record: scalar Tensor tf.string containing a serialized
Example protocol buffer. Example protocol buffer.
is_training: A boolean denoting whether the input is for training. is_training: A boolean denoting whether the input is for training.
dtype: data type to use for images/features.
Returns: Returns:
Tuple with processed image tensor and one-hot-encoded label tensor. Tuple with processed image tensor and one-hot-encoded label tensor.
...@@ -152,11 +153,13 @@ def parse_record(raw_record, is_training): ...@@ -152,11 +153,13 @@ def parse_record(raw_record, is_training):
output_width=_DEFAULT_IMAGE_SIZE, output_width=_DEFAULT_IMAGE_SIZE,
num_channels=_NUM_CHANNELS, num_channels=_NUM_CHANNELS,
is_training=is_training) is_training=is_training)
image = tf.cast(image, dtype)
return image, label return image, label
def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None): def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
dtype=tf.float32):
"""Input function which provides batches for train or eval. """Input function which provides batches for train or eval.
Args: Args:
...@@ -165,6 +168,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None): ...@@ -165,6 +168,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
batch_size: The number of samples per batch. batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training. num_gpus: The number of gpus used for training.
dtype: Data type to use for images/features
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
...@@ -192,7 +196,8 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None): ...@@ -192,7 +196,8 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
parse_record_fn=parse_record, parse_record_fn=parse_record,
num_epochs=num_epochs, num_epochs=num_epochs,
num_gpus=num_gpus, num_gpus=num_gpus,
examples_per_epoch=_NUM_IMAGES['train'] if is_training else None examples_per_epoch=_NUM_IMAGES['train'] if is_training else None,
dtype=dtype
) )
......
...@@ -45,7 +45,7 @@ from official.utils.misc import model_helpers ...@@ -45,7 +45,7 @@ from official.utils.misc import model_helpers
################################################################################ ################################################################################
def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
parse_record_fn, num_epochs=1, num_gpus=None, parse_record_fn, num_epochs=1, num_gpus=None,
examples_per_epoch=None): examples_per_epoch=None, dtype=tf.float32):
"""Given a Dataset with raw records, return an iterator over the records. """Given a Dataset with raw records, return an iterator over the records.
Args: Args:
...@@ -60,6 +60,7 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, ...@@ -60,6 +60,7 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training. num_gpus: The number of gpus used for training.
examples_per_epoch: The number of examples in an epoch. examples_per_epoch: The number of examples in an epoch.
dtype: Data type to use for images/features.
Returns: Returns:
Dataset of (image, label) pairs ready for iteration. Dataset of (image, label) pairs ready for iteration.
...@@ -92,7 +93,7 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, ...@@ -92,7 +93,7 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
# batch_size is almost always much greater than the number of CPU cores. # batch_size is almost always much greater than the number of CPU cores.
dataset = dataset.apply( dataset = dataset.apply(
tf.contrib.data.map_and_batch( tf.contrib.data.map_and_batch(
lambda value: parse_record_fn(value, is_training), lambda value: parse_record_fn(value, is_training, dtype),
batch_size=batch_size, batch_size=batch_size,
num_parallel_batches=1, num_parallel_batches=1,
drop_remainder=False)) drop_remainder=False))
...@@ -248,8 +249,8 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -248,8 +249,8 @@ def resnet_model_fn(features, labels, mode, model_class,
# Generate a summary node for the images # Generate a summary node for the images
tf.summary.image('images', features, max_outputs=6) tf.summary.image('images', features, max_outputs=6)
# TODO(tobyboyd): Add cast as part of input pipeline on cpu and remove. # Checks that features/images have same data type being used for calculations.
features = tf.cast(features, dtype) assert features.dtype == dtype
model = model_class(resnet_size, data_format, resnet_version=resnet_version, model = model_class(resnet_size, data_format, resnet_version=resnet_version,
dtype=dtype) dtype=dtype)
...@@ -454,14 +455,16 @@ def resnet_main( ...@@ -454,14 +455,16 @@ def resnet_main(
batch_size=distribution_utils.per_device_batch_size( batch_size=distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)), flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
num_epochs=num_epochs, num_epochs=num_epochs,
num_gpus=flags_core.get_num_gpus(flags_obj)) num_gpus=flags_core.get_num_gpus(flags_obj),
dtype=flags_core.get_tf_dtype(flags_obj))
def input_fn_eval(): def input_fn_eval():
return input_function( return input_function(
is_training=False, data_dir=flags_obj.data_dir, is_training=False, data_dir=flags_obj.data_dir,
batch_size=distribution_utils.per_device_batch_size( batch_size=distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)), flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
num_epochs=1) num_epochs=1,
dtype=flags_core.get_tf_dtype(flags_obj))
if flags_obj.eval_only or not flags_obj.train_epochs: if flags_obj.eval_only or not flags_obj.train_epochs:
# If --eval_only is set, perform a single loop with zero train epochs. # If --eval_only is set, perform a single loop with zero train epochs.
...@@ -533,7 +536,7 @@ def define_resnet_flags(resnet_size_choices=None): ...@@ -533,7 +536,7 @@ def define_resnet_flags(resnet_size_choices=None):
'If not None initialize all the network except the final layer with ' 'If not None initialize all the network except the final layer with '
'these values')) 'these values'))
flags.DEFINE_boolean( flags.DEFINE_boolean(
name="eval_only", default=False, name='eval_only', default=False,
help=flags_core.help_wrap('Skip training and only perform evaluation on ' help=flags_core.help_wrap('Skip training and only perform evaluation on '
'the latest checkpoint.')) 'the latest checkpoint.'))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment