"docs/vscode:/vscode.git/clone" did not exist on "4fd6e7103006e01f7a4f5d723b13ea0e789ff3ce"
Commit c9972ad6 authored by Toby Boyd

Improve synthetic data performance

parent 23b5b422
cifar10_main.py
@@ -135,9 +135,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
   )
-def get_synth_input_fn():
+def get_synth_input_fn(dtype):
   return resnet_run_loop.get_synth_input_fn(
-      _HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES)
+      _HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES, dtype=dtype)
 ###############################################################################
@@ -243,8 +243,9 @@ def run_cifar(flags_obj):
   Args:
     flags_obj: An object containing parsed flag values.
   """
-  input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
-                    or input_fn)
+  input_function = (flags_obj.use_synthetic_data and
+                    get_synth_input_fn(flags_core.get_tf_dtype(flags_obj)) or
+                    input_fn)
   resnet_run_loop.resnet_main(
       flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
       shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
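
A note on the `A and B or C` idiom used above to pick input_function: because get_synth_input_fn(...) returns a function object, which is always truthy, the expression behaves like a plain conditional. A minimal sketch with hypothetical stand-in names (not part of this commit):

    # Hypothetical stand-ins for get_synth_input_fn(dtype) and input_fn.
    def synth_fn():
      return 'synthetic batches'

    def real_fn():
      return 'real batches'

    use_synthetic_data = True
    # The short-circuit idiom from the diff above...
    chosen = use_synthetic_data and synth_fn or real_fn
    # ...picks the same function as an explicit conditional expression,
    # because synth_fn is truthy.
    assert chosen is (synth_fn if use_synthetic_data else real_fn)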
cifar10_test.py
@@ -77,9 +77,9 @@ class BaseTest(tf.test.TestCase):
     self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
   def cifar10_model_fn_helper(self, mode, resnet_version, dtype):
-    input_fn = cifar10_main.get_synth_input_fn()
+    input_fn = cifar10_main.get_synth_input_fn(dtype)
     dataset = input_fn(True, '', _BATCH_SIZE)
-    iterator = dataset.make_one_shot_iterator()
+    iterator = dataset.make_initializable_iterator()
     features, labels = iterator.get_next()
     spec = cifar10_main.cifar10_model_fn(
         features, labels, mode, {
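
The switch from make_one_shot_iterator() to make_initializable_iterator() matters at run time: in TF 1.x an initializable iterator must be initialized explicitly before get_next() is evaluated. A minimal sketch of that idiom (not from this commit; shapes are arbitrary):

    import tensorflow as tf

    dataset = tf.data.Dataset.from_tensors(
        (tf.zeros([4, 32, 32, 3]), tf.zeros([4], dtype=tf.int32))).repeat()
    iterator = dataset.make_initializable_iterator()
    features, labels = iterator.get_next()

    with tf.Session() as sess:
      sess.run(iterator.initializer)  # required before get_next() can be run
      batch_features, batch_labels = sess.run([features, labels])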
imagenet_main.py
@@ -196,9 +196,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
   )
-def get_synth_input_fn():
+def get_synth_input_fn(dtype):
   return resnet_run_loop.get_synth_input_fn(
-      _DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES)
+      _DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES,
+      dtype=dtype)
 ###############################################################################
@@ -331,8 +332,9 @@ def run_imagenet(flags_obj):
   Args:
     flags_obj: An object containing parsed flag values.
   """
-  input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
-                    or input_fn)
+  input_function = (flags_obj.use_synthetic_data and
+                    get_synth_input_fn(flags_core.get_tf_dtype(flags_obj)) or
+                    input_fn)
   resnet_run_loop.resnet_main(
       flags_obj, imagenet_model_fn, input_function, DATASET_NAME,
imagenet_test.py
@@ -191,9 +191,9 @@ class BaseTest(tf.test.TestCase):
     """Tests that the EstimatorSpec is given the appropriate arguments."""
     tf.train.create_global_step()
-    input_fn = imagenet_main.get_synth_input_fn()
+    input_fn = imagenet_main.get_synth_input_fn(dtype)
     dataset = input_fn(True, '', _BATCH_SIZE)
-    iterator = dataset.make_one_shot_iterator()
+    iterator = dataset.make_initializable_iterator()
     features, labels = iterator.get_next()
     spec = imagenet_main.imagenet_model_fn(
         features, labels, mode, {
resnet_run_loop.py
@@ -108,11 +108,12 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
   return dataset
-def get_synth_input_fn(height, width, num_channels, num_classes):
-  """Returns an input function that returns a dataset with zeroes.
-  This is useful in debugging input pipeline performance, as it removes all
-  elements of file reading and image preprocessing.
+def get_synth_input_fn(height, width, num_channels, num_classes,
+                       dtype=tf.float32):
+  """Returns an input function that returns a dataset with random data.
+  This input_fn removed all aspects of the input pipeline other than the
+  host to device copy. This is useful in debugging input pipeline performance.
   Args:
     height: Integer height that will be used to create a fake image tensor.
@@ -120,17 +121,32 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
     num_channels: Integer depth that will be used to create a fake image tensor.
     num_classes: Number of classes that should be represented in the fake labels
       tensor
+    dtype: Data type for features/images.
   Returns:
     An input_fn that can be used in place of a real one to return a dataset
     that can be used for iteration.
   """
-  def input_fn(is_training, data_dir, batch_size, *args, **kwargs):  # pylint: disable=unused-argument
-    return model_helpers.generate_synthetic_data(
-        input_shape=tf.TensorShape([batch_size, height, width, num_channels]),
-        input_dtype=tf.float32,
-        label_shape=tf.TensorShape([batch_size]),
-        label_dtype=tf.int32)
+  # pylint: disable=unused-argument
+  def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
+    """Returns dataset filled with random data."""
+    # Synthetic input should be within [0, 255].
+    inputs = tf.truncated_normal(
+        [batch_size] + [height, width, num_channels],
+        dtype=dtype,
+        mean=127,
+        stddev=60,
+        name='synthetic_inputs')
+    labels = tf.random_uniform(
+        [batch_size],
+        minval=0,
+        maxval=num_classes - 1,
+        dtype=tf.int32,
+        name='synthetic_labels')
+    data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
+    data = data.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
+    return data
   return input_fn
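
A hypothetical usage sketch of the rewritten helper under a recent TF 1.x (the ImageNet-like sizes and float16 dtype are assumptions, not part of the commit):

    # Assumed sizes: 224x224x3 images, 1001 classes, fp16 features.
    synth_input_fn = get_synth_input_fn(224, 224, 3, 1001, dtype=tf.float16)
    dataset = synth_input_fn(is_training=True, data_dir='', batch_size=32)

    iterator = dataset.make_initializable_iterator()
    images, labels = iterator.get_next()
    with tf.Session() as sess:
      sess.run(iterator.initializer)
      batch_images, batch_labels = sess.run([images, labels])
      # batch_images: shape (32, 224, 224, 3), dtype float16; labels: int32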
@@ -230,7 +246,7 @@ def resnet_model_fn(features, labels, mode, model_class,
   # Generate a summary node for the images
   tf.summary.image('images', features, max_outputs=6)
+  # TODO(tobyboyd): Add cast as part of input pipeline on cpu and remove.
   features = tf.cast(features, dtype)
   model = model_class(resnet_size, data_format, resnet_version=resnet_version,
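
The new TODO points at moving this cast out of the model function and into the input pipeline. One hypothetical way to do that with tf.data (illustrative only, not what this commit implements):

    # Sketch: cast features on the host inside the input pipeline via Dataset.map,
    # so resnet_model_fn would no longer need tf.cast(features, dtype).
    def cast_features(features, labels):
      return tf.cast(features, tf.float16), labels

    dataset = tf.data.Dataset.from_tensors(
        (tf.zeros([8, 32, 32, 3], tf.float32), tf.zeros([8], tf.int32))).repeat()
    dataset = dataset.map(cast_features)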