"git@developer.sourcefind.cn:change/sglang.git" did not exist on "6beeff41c5b8133d6a964d011f332a9ebb28a12f"
Unverified commit 481728db authored by Toby Boyd, committed by GitHub

Merge pull request #5225 from tfboyd/resnet_synthetic_fix

ResNet synthetic data performance enhancement.
parents e0f6a392 967133c1
official/resnet/cifar10_main.py
@@ -135,9 +135,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
   )


-def get_synth_input_fn():
+def get_synth_input_fn(dtype):
   return resnet_run_loop.get_synth_input_fn(
-      _HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES)
+      _HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES, dtype=dtype)


 ###############################################################################
@@ -243,8 +243,9 @@ def run_cifar(flags_obj):
   Args:
     flags_obj: An object containing parsed flag values.
   """
-  input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
-                    or input_fn)
+  input_function = (flags_obj.use_synthetic_data and
+                    get_synth_input_fn(flags_core.get_tf_dtype(flags_obj)) or
+                    input_fn)

   resnet_run_loop.resnet_main(
       flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
       shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
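The `use_synthetic_data and ... or ...` expression above relies on Python's short-circuit evaluation: `get_synth_input_fn(...)` always returns a (truthy) function object, so `or input_fn` is only reached when the flag is false. A minimal restatement of that selection as a conditional expression, reusing the names that appear in the hunk above, would look like this:

```python
# Sketch only: equivalent input-function selection written as a conditional
# expression. `flags_obj`, `flags_core`, `get_synth_input_fn` and `input_fn`
# are the names used in the diff above, not new definitions.
input_function = (
    get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
    if flags_obj.use_synthetic_data
    else input_fn)
```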
official/resnet/cifar10_test.py
@@ -77,9 +77,9 @@ class BaseTest(tf.test.TestCase):
     self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)

   def cifar10_model_fn_helper(self, mode, resnet_version, dtype):
-    input_fn = cifar10_main.get_synth_input_fn()
+    input_fn = cifar10_main.get_synth_input_fn(dtype)
     dataset = input_fn(True, '', _BATCH_SIZE)
-    iterator = dataset.make_one_shot_iterator()
+    iterator = dataset.make_initializable_iterator()
     features, labels = iterator.get_next()
     spec = cifar10_main.cifar10_model_fn(
         features, labels, mode, {
official/resnet/imagenet_main.py
@@ -196,9 +196,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
   )


-def get_synth_input_fn():
+def get_synth_input_fn(dtype):
   return resnet_run_loop.get_synth_input_fn(
-      _DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES)
+      _DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES,
+      dtype=dtype)


 ###############################################################################
@@ -331,8 +332,9 @@ def run_imagenet(flags_obj):
   Args:
     flags_obj: An object containing parsed flag values.
   """
-  input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
-                    or input_fn)
+  input_function = (flags_obj.use_synthetic_data and
+                    get_synth_input_fn(flags_core.get_tf_dtype(flags_obj)) or
+                    input_fn)

   resnet_run_loop.resnet_main(
       flags_obj, imagenet_model_fn, input_function, DATASET_NAME,
official/resnet/imagenet_test.py
@@ -191,9 +191,9 @@ class BaseTest(tf.test.TestCase):
     """Tests that the EstimatorSpec is given the appropriate arguments."""
     tf.train.create_global_step()

-    input_fn = imagenet_main.get_synth_input_fn()
+    input_fn = imagenet_main.get_synth_input_fn(dtype)
     dataset = input_fn(True, '', _BATCH_SIZE)
-    iterator = dataset.make_one_shot_iterator()
+    iterator = dataset.make_initializable_iterator()
     features, labels = iterator.get_next()
     spec = imagenet_main.imagenet_model_fn(
         features, labels, mode, {
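Both test files switch from `make_one_shot_iterator()` to `make_initializable_iterator()`, likely because the synthetic dataset now wraps random tensors created outside `tf.data`: a one-shot iterator captures such inputs by value, which TF 1.x rejects for stateful ops, while an initializable iterator only requires running its initializer first. A small self-contained sketch of that pattern (TF 1.x graph mode; shapes are arbitrary):

```python
import tensorflow as tf

# Sketch of the iterator pattern used by the updated tests (TF 1.x graph mode).
# The dataset wraps a stateful random op created outside tf.data.
inputs = tf.random_uniform([4, 8], name='synthetic_batch')
dataset = tf.data.Dataset.from_tensors(inputs).repeat()

iterator = dataset.make_initializable_iterator()
batch = iterator.get_next()

with tf.Session() as sess:
  sess.run(iterator.initializer)   # must run before the first get_next()
  print(sess.run(batch).shape)     # (4, 8)
```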
official/resnet/resnet_run_loop.py
@@ -108,11 +108,14 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
   return dataset


-def get_synth_input_fn(height, width, num_channels, num_classes):
-  """Returns an input function that returns a dataset with zeroes.
+def get_synth_input_fn(height, width, num_channels, num_classes,
+                       dtype=tf.float32):
+  """Returns an input function that returns a dataset with random data.

-  This is useful in debugging input pipeline performance, as it removes all
-  elements of file reading and image preprocessing.
+  This input_fn returns a dataset that iterates over a set of random data and
+  bypasses all preprocessing, e.g. jpeg decode and copy. The host to device
+  copy is still included. This is used to find the upper throughput bound when
+  tuning the full input pipeline.

   Args:
     height: Integer height that will be used to create a fake image tensor.
@@ -120,17 +123,32 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
     num_channels: Integer depth that will be used to create a fake image tensor.
     num_classes: Number of classes that should be represented in the fake labels
       tensor
+    dtype: Data type for features/images.

   Returns:
     An input_fn that can be used in place of a real one to return a dataset
     that can be used for iteration.
   """
-  def input_fn(is_training, data_dir, batch_size, *args, **kwargs):  # pylint: disable=unused-argument
-    return model_helpers.generate_synthetic_data(
-        input_shape=tf.TensorShape([batch_size, height, width, num_channels]),
-        input_dtype=tf.float32,
-        label_shape=tf.TensorShape([batch_size]),
-        label_dtype=tf.int32)
+  # pylint: disable=unused-argument
+  def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
+    """Returns dataset filled with random data."""
+    # Synthetic input should be within [0, 255].
+    inputs = tf.truncated_normal(
+        [batch_size] + [height, width, num_channels],
+        dtype=dtype,
+        mean=127,
+        stddev=60,
+        name='synthetic_inputs')
+
+    labels = tf.random_uniform(
+        [batch_size],
+        minval=0,
+        maxval=num_classes - 1,
+        dtype=tf.int32,
+        name='synthetic_labels')
+
+    data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
+    data = data.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
+    return data

   return input_fn
@@ -230,7 +248,7 @@ def resnet_model_fn(features, labels, mode, model_class,
   # Generate a summary node for the images
   tf.summary.image('images', features, max_outputs=6)

+  # TODO(tobyboyd): Add cast as part of input pipeline on cpu and remove.
   features = tf.cast(features, dtype)

   model = model_class(resnet_size, data_format, resnet_version=resnet_version,
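A minimal sketch of exercising the reworked synthetic input_fn end to end, assuming TF 1.x graph mode and that the module is importable as `official.resnet.resnet_run_loop` (as in the tensorflow/models layout this PR targets); the image size, class count, batch size, and fp16 dtype below are illustrative values only:

```python
import tensorflow as tf
from official.resnet import resnet_run_loop  # assumed import path

# Build a synthetic input_fn that yields half-precision images.
input_fn = resnet_run_loop.get_synth_input_fn(
    height=224, width=224, num_channels=3, num_classes=1001,
    dtype=tf.float16)

# is_training and data_dir are accepted but ignored by the synthetic input_fn.
dataset = input_fn(True, '', batch_size=32)
iterator = dataset.make_initializable_iterator()
features, labels = iterator.get_next()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  images, classes = sess.run([features, labels])
  print(images.shape, images.dtype)    # (32, 224, 224, 3) float16
  print(classes.shape, classes.dtype)  # (32,) int32
```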