Unverified commit 481728db, authored by Toby Boyd, committed by GitHub

Merge pull request #5225 from tfboyd/resnet_synthetic_fix

ResNet synthetic data performance enhancement.
parents e0f6a392 967133c1
@@ -135,9 +135,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
   )
 
 
-def get_synth_input_fn():
+def get_synth_input_fn(dtype):
   return resnet_run_loop.get_synth_input_fn(
-      _HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES)
+      _HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES, dtype=dtype)
 
 
 ###############################################################################
@@ -243,8 +243,9 @@ def run_cifar(flags_obj):
   Args:
     flags_obj: An object containing parsed flag values.
   """
-  input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
-                    or input_fn)
+  input_function = (flags_obj.use_synthetic_data and
+                    get_synth_input_fn(flags_core.get_tf_dtype(flags_obj)) or
+                    input_fn)
   resnet_run_loop.resnet_main(
       flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
       shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
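Note: the selection above relies on Python's `and`/`or` short-circuiting. Because both `get_synth_input_fn(...)` and `input_fn` are truthy function objects, the chain behaves like a conditional expression, and `get_synth_input_fn` is only evaluated when `--use_synthetic_data` is set. A minimal sketch of the same logic written as a conditional expression, reusing the names already defined in `run_cifar` above:

```python
# Sketch only: equivalent to the and/or chain in the diff above.
# `flags_obj`, `flags_core`, `get_synth_input_fn`, and `input_fn` are the
# objects defined in the surrounding module, not new names.
input_function = (
    get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
    if flags_obj.use_synthetic_data
    else input_fn)
```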
@@ -77,9 +77,9 @@ class BaseTest(tf.test.TestCase):
     self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
 
   def cifar10_model_fn_helper(self, mode, resnet_version, dtype):
-    input_fn = cifar10_main.get_synth_input_fn()
+    input_fn = cifar10_main.get_synth_input_fn(dtype)
     dataset = input_fn(True, '', _BATCH_SIZE)
-    iterator = dataset.make_one_shot_iterator()
+    iterator = dataset.make_initializable_iterator()
     features, labels = iterator.get_next()
     spec = cifar10_main.cifar10_model_fn(
         features, labels, mode, {
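Note: the tests switch from `make_one_shot_iterator` to `make_initializable_iterator` because the synthetic dataset now captures stateful random ops (`tf.truncated_normal`, `tf.random_uniform`), which a one-shot iterator cannot capture by value in TF 1.x graph mode; the iterator therefore needs an explicit initialization step. A self-contained sketch of that pattern, with illustrative shapes and batch size rather than the tests' real constants:

```python
import tensorflow as tf

# Illustrative stand-in for the constant used by the real tests.
_BATCH_SIZE = 32

# Random ops make the dataset pipeline stateful, so make_one_shot_iterator()
# would fail here; an initializable iterator works.
inputs = tf.truncated_normal([_BATCH_SIZE, 32, 32, 3], mean=127, stddev=60)
labels = tf.random_uniform([_BATCH_SIZE], maxval=10, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensors((inputs, labels)).repeat()

iterator = dataset.make_initializable_iterator()
features, label_batch = iterator.get_next()

with tf.Session() as sess:
  sess.run(iterator.initializer)  # explicit init; one-shot iterators skip this
  f, l = sess.run([features, label_batch])
  print(f.shape, l.shape)  # (32, 32, 32, 3) (32,)
```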
@@ -196,9 +196,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
   )
 
 
-def get_synth_input_fn():
+def get_synth_input_fn(dtype):
   return resnet_run_loop.get_synth_input_fn(
-      _DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES)
+      _DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES,
+      dtype=dtype)
 
 
 ###############################################################################
@@ -331,8 +332,9 @@ def run_imagenet(flags_obj):
   Args:
     flags_obj: An object containing parsed flag values.
   """
-  input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
-                    or input_fn)
+  input_function = (flags_obj.use_synthetic_data and
+                    get_synth_input_fn(flags_core.get_tf_dtype(flags_obj)) or
+                    input_fn)
   resnet_run_loop.resnet_main(
       flags_obj, imagenet_model_fn, input_function, DATASET_NAME,
@@ -191,9 +191,9 @@ class BaseTest(tf.test.TestCase):
     """Tests that the EstimatorSpec is given the appropriate arguments."""
     tf.train.create_global_step()
 
-    input_fn = imagenet_main.get_synth_input_fn()
+    input_fn = imagenet_main.get_synth_input_fn(dtype)
     dataset = input_fn(True, '', _BATCH_SIZE)
-    iterator = dataset.make_one_shot_iterator()
+    iterator = dataset.make_initializable_iterator()
     features, labels = iterator.get_next()
     spec = imagenet_main.imagenet_model_fn(
         features, labels, mode, {
@@ -108,11 +108,14 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
   return dataset
 
 
-def get_synth_input_fn(height, width, num_channels, num_classes):
-  """Returns an input function that returns a dataset with zeroes.
+def get_synth_input_fn(height, width, num_channels, num_classes,
+                       dtype=tf.float32):
+  """Returns an input function that returns a dataset with random data.
 
-  This is useful in debugging input pipeline performance, as it removes all
-  elements of file reading and image preprocessing.
+  This input_fn returns a dataset that iterates over a set of random data and
+  bypasses all preprocessing, e.g. jpeg decode and copy. The host to device
+  copy is still included. This is used to find the upper throughput bound when
+  tuning the full input pipeline.
 
   Args:
     height: Integer height that will be used to create a fake image tensor.
@@ -120,17 +123,32 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
     num_channels: Integer depth that will be used to create a fake image tensor.
     num_classes: Number of classes that should be represented in the fake labels
       tensor
+    dtype: Data type for features/images.
 
   Returns:
     An input_fn that can be used in place of a real one to return a dataset
     that can be used for iteration.
   """
-  def input_fn(is_training, data_dir, batch_size, *args, **kwargs):  # pylint: disable=unused-argument
-    return model_helpers.generate_synthetic_data(
-        input_shape=tf.TensorShape([batch_size, height, width, num_channels]),
-        input_dtype=tf.float32,
-        label_shape=tf.TensorShape([batch_size]),
-        label_dtype=tf.int32)
+  # pylint: disable=unused-argument
+  def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
+    """Returns dataset filled with random data."""
+    # Synthetic input should be within [0, 255].
+    inputs = tf.truncated_normal(
+        [batch_size] + [height, width, num_channels],
+        dtype=dtype,
+        mean=127,
+        stddev=60,
+        name='synthetic_inputs')
+
+    labels = tf.random_uniform(
+        [batch_size],
+        minval=0,
+        maxval=num_classes - 1,
+        dtype=tf.int32,
+        name='synthetic_labels')
+
+    data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
+    data = data.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
+    return data
+
   return input_fn
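Since the stated purpose of the rewritten input_fn is to find the upper throughput bound of the input pipeline, here is a hedged, self-contained TF 1.x sketch of how such a synthetic dataset could be benchmarked on its own. Shapes, dtype, and the step count are illustrative choices, not values from the PR; `tf.contrib.data.AUTOTUNE` matches the prefetch call in the diff above:

```python
import time
import tensorflow as tf

# Illustrative parameters, not taken from the PR.
batch_size, height, width, channels, num_classes = 64, 224, 224, 3, 1000

# Same construction as the new input_fn: random tensors wrapped in a
# repeated, prefetched dataset, with no file reading or preprocessing.
inputs = tf.truncated_normal(
    [batch_size, height, width, channels], mean=127, stddev=60,
    dtype=tf.float32, name='synthetic_inputs')
labels = tf.random_uniform(
    [batch_size], minval=0, maxval=num_classes - 1, dtype=tf.int32,
    name='synthetic_labels')

dataset = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  steps = 100
  start = time.time()
  for _ in range(steps):
    sess.run(next_batch)
  elapsed = time.time() - start
  print('synthetic images/sec: %.1f' % (steps * batch_size / elapsed))
```

Because `from_tensors(...).repeat()` captures the random tensors when the iterator is initialized, the same batch is effectively replayed each step, so per-step random generation cost does not enter the measurement.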
@@ -230,7 +248,7 @@ def resnet_model_fn(features, labels, mode, model_class,
   # Generate a summary node for the images
   tf.summary.image('images', features, max_outputs=6)
 
+  # TODO(tobyboyd): Add cast as part of input pipeline on cpu and remove.
   features = tf.cast(features, dtype)
 
   model = model_class(resnet_size, data_format, resnet_version=resnet_version,