Commit c9972ad6 authored by Toby Boyd's avatar Toby Boyd
Browse files

Improve synthic data performance

parent 23b5b422
......@@ -135,9 +135,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
)
def get_synth_input_fn():
def get_synth_input_fn(dtype):
return resnet_run_loop.get_synth_input_fn(
_HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES)
_HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES, dtype=dtype)
###############################################################################
......@@ -243,8 +243,9 @@ def run_cifar(flags_obj):
Args:
flags_obj: An object containing parsed flag values.
"""
input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
or input_fn)
input_function = (flags_obj.use_synthetic_data and
get_synth_input_fn(flags_core.get_tf_dtype(flags_obj)) or
input_fn)
resnet_run_loop.resnet_main(
flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
......
......@@ -77,9 +77,9 @@ class BaseTest(tf.test.TestCase):
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
def cifar10_model_fn_helper(self, mode, resnet_version, dtype):
input_fn = cifar10_main.get_synth_input_fn()
input_fn = cifar10_main.get_synth_input_fn(dtype)
dataset = input_fn(True, '', _BATCH_SIZE)
iterator = dataset.make_one_shot_iterator()
iterator = dataset.make_initializable_iterator()
features, labels = iterator.get_next()
spec = cifar10_main.cifar10_model_fn(
features, labels, mode, {
......
......@@ -196,9 +196,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
)
def get_synth_input_fn():
def get_synth_input_fn(dtype):
return resnet_run_loop.get_synth_input_fn(
_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES)
_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES,
dtype=dtype)
###############################################################################
......@@ -331,8 +332,9 @@ def run_imagenet(flags_obj):
Args:
flags_obj: An object containing parsed flag values.
"""
input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
or input_fn)
input_function = (flags_obj.use_synthetic_data and
get_synth_input_fn(flags_core.get_tf_dtype(flags_obj)) or
input_fn)
resnet_run_loop.resnet_main(
flags_obj, imagenet_model_fn, input_function, DATASET_NAME,
......
......@@ -191,9 +191,9 @@ class BaseTest(tf.test.TestCase):
"""Tests that the EstimatorSpec is given the appropriate arguments."""
tf.train.create_global_step()
input_fn = imagenet_main.get_synth_input_fn()
input_fn = imagenet_main.get_synth_input_fn(dtype)
dataset = input_fn(True, '', _BATCH_SIZE)
iterator = dataset.make_one_shot_iterator()
iterator = dataset.make_initializable_iterator()
features, labels = iterator.get_next()
spec = imagenet_main.imagenet_model_fn(
features, labels, mode, {
......
......@@ -108,11 +108,12 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
return dataset
def get_synth_input_fn(height, width, num_channels, num_classes):
"""Returns an input function that returns a dataset with zeroes.
def get_synth_input_fn(height, width, num_channels, num_classes,
dtype=tf.float32):
"""Returns an input function that returns a dataset with random data.
This is useful in debugging input pipeline performance, as it removes all
elements of file reading and image preprocessing.
This input_fn removed all aspects of the input pipeline other than the
host to device copy. This is useful in debugging input pipeline performance.
Args:
height: Integer height that will be used to create a fake image tensor.
......@@ -120,17 +121,32 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
num_channels: Integer depth that will be used to create a fake image tensor.
num_classes: Number of classes that should be represented in the fake labels
tensor
dtype: Data type for features/images.
Returns:
An input_fn that can be used in place of a real one to return a dataset
that can be used for iteration.
"""
def input_fn(is_training, data_dir, batch_size, *args, **kwargs): # pylint: disable=unused-argument
return model_helpers.generate_synthetic_data(
input_shape=tf.TensorShape([batch_size, height, width, num_channels]),
input_dtype=tf.float32,
label_shape=tf.TensorShape([batch_size]),
label_dtype=tf.int32)
# pylint: disable=unused-argument
def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
"""Returns dataset filled with random data."""
# Synthetic input should be within [0, 255].
inputs = tf.truncated_normal(
[batch_size] + [height, width, num_channels],
dtype=dtype,
mean=127,
stddev=60,
name='synthetic_inputs')
labels = tf.random_uniform(
[batch_size],
minval=0,
maxval=num_classes - 1,
dtype=tf.int32,
name='synthetic_labels')
data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
data = data.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
return data
return input_fn
......@@ -230,7 +246,7 @@ def resnet_model_fn(features, labels, mode, model_class,
# Generate a summary node for the images
tf.summary.image('images', features, max_outputs=6)
# TODO(tobyboyd): Add cast as part of input pipeline on cpu and remove.
features = tf.cast(features, dtype)
model = model_class(resnet_size, data_format, resnet_version=resnet_version,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment