Unverified Commit 74c43aae authored by Karmel Allison's avatar Karmel Allison Committed by GitHub
Browse files

Add synthetic data option to Resnet (#3503)

* Adding option to run with synthetic data.

* Adding option to run with synthetic data.

Adding option to run with synthetic data.

Adding option to run with synthetic data.

Debugging

Debugging

Debugging

Removing dataset

Removing dataset

Updating comments

Updating tests

Updating tests

Clarifying name of fn

Tests

* Copy pasta

* Using dataset as recommended by mrry

* Updating tests to use datasets
parent a8bb5926
......@@ -131,6 +131,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
examples_per_epoch=num_images, multi_gpu=multi_gpu)
def get_synth_input_fn():
  """Builds the synthetic-data input function for this model.

  Returns:
    The input function produced by `resnet.get_synth_input_fn`, configured
    with this module's `_HEIGHT`, `_WIDTH`, `_NUM_CHANNELS` and `_NUM_CLASSES`.
  """
  synth_input_fn = resnet.get_synth_input_fn(
      height=_HEIGHT,
      width=_WIDTH,
      num_channels=_NUM_CHANNELS,
      num_classes=_NUM_CLASSES)
  return synth_input_fn
###############################################################################
# Running the model
###############################################################################
......@@ -200,7 +204,8 @@ def cifar10_model_fn(features, labels, mode, params):
def main(unused_argv):
  """Runs the CIFAR-10 ResNet, on synthetic input when requested.

  Args:
    unused_argv: Unused.
  """
  # A conditional expression states the choice directly; the `x and a or b`
  # form silently falls through to `b` whenever `a` happens to be falsy.
  if FLAGS.use_synthetic_data:
    input_function = get_synth_input_fn()
  else:
    input_function = input_fn

  # Single entry into the shared training/eval loop with the chosen input.
  resnet.resnet_main(FLAGS, cifar10_model_fn, input_function)
if __name__ == '__main__':
......
......@@ -64,14 +64,11 @@ class BaseTest(tf.test.TestCase):
for pixel in row:
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
def input_fn(self):
  """Provides random features and one-hot labels for the model tests.

  Returns:
    A `(features, labels)` tuple: features are uniform random values of shape
    `[_BATCH_SIZE, _HEIGHT, _WIDTH, _NUM_CHANNELS]`; labels are one-hot
    encoded with depth 10.
  """
  num_classes = 10
  features = tf.random_uniform([_BATCH_SIZE, _HEIGHT, _WIDTH, _NUM_CHANNELS])
  # `maxval` is exclusive for integer sampling, so it must equal the class
  # count for every class index (0..num_classes - 1) to be reachable.
  labels = tf.random_uniform(
      [_BATCH_SIZE], maxval=num_classes, dtype=tf.int32)
  return features, tf.one_hot(labels, num_classes)
def cifar10_model_fn_helper(self, mode, multi_gpu=False):
features, labels = self.input_fn()
input_fn = cifar10_main.get_synth_input_fn()
dataset = input_fn(True, '', _BATCH_SIZE)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
spec = cifar10_main.cifar10_model_fn(
features, labels, mode, {
'resnet_size': 32,
......
......@@ -148,9 +148,15 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
# Convert to individual records
dataset = dataset.flat_map(tf.data.TFRecordDataset)
return resnet.process_record_dataset(dataset, is_training, batch_size,
_SHUFFLE_BUFFER, parse_record, num_epochs, num_parallel_calls,
examples_per_epoch=num_images, multi_gpu=multi_gpu)
return resnet.process_record_dataset(
dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record,
num_epochs, num_parallel_calls, examples_per_epoch=num_images,
multi_gpu=multi_gpu)
def get_synth_input_fn():
  """Builds the synthetic-data input function for this model.

  Returns:
    The input function produced by `resnet.get_synth_input_fn`, using
    `_DEFAULT_IMAGE_SIZE` for both height and width together with this
    module's `_NUM_CHANNELS` and `_NUM_CLASSES`.
  """
  synth_input_fn = resnet.get_synth_input_fn(
      height=_DEFAULT_IMAGE_SIZE,
      width=_DEFAULT_IMAGE_SIZE,
      num_channels=_NUM_CHANNELS,
      num_classes=_NUM_CLASSES)
  return synth_input_fn
###############################################################################
......@@ -235,7 +241,8 @@ def imagenet_model_fn(features, labels, mode, params):
def main(unused_argv):
  """Runs the ImageNet ResNet, on synthetic input when requested.

  Args:
    unused_argv: Unused.
  """
  # A conditional expression states the choice directly; the `x and a or b`
  # form silently falls through to `b` whenever `a` happens to be falsy.
  if FLAGS.use_synthetic_data:
    input_function = get_synth_input_fn()
  else:
    input_function = input_fn

  # Single entry into the shared training/eval loop with the chosen input.
  resnet.resnet_main(FLAGS, imagenet_model_fn, input_function)
if __name__ == '__main__':
......
......@@ -125,22 +125,14 @@ class BaseTest(tf.test.TestCase):
def test_tensor_shapes_resnet_200_with_gpu(self):
self.tensor_shapes_helper(200, True)
def input_fn(self):
  """Provides random features and labels.

  Returns:
    A `(features, labels)` tuple: features are uniform random values of shape
    `[_BATCH_SIZE, 224, 224, 3]`; labels are one-hot encoded with depth
    `_LABEL_CLASSES`.
  """
  features = tf.random_uniform([_BATCH_SIZE, 224, 224, 3])
  # `maxval` is exclusive for integer sampling, so passing `_LABEL_CLASSES`
  # (not `_LABEL_CLASSES - 1`) is what makes the last class reachable.
  class_ids = tf.random_uniform(
      [_BATCH_SIZE], maxval=_LABEL_CLASSES, dtype=tf.int32)
  labels = tf.one_hot(class_ids, _LABEL_CLASSES)
  return features, labels
def resnet_model_fn_helper(self, mode, multi_gpu=False):
"""Tests that the EstimatorSpec is given the appropriate arguments."""
tf.train.create_global_step()
features, labels = self.input_fn()
input_fn = imagenet_main.get_synth_input_fn()
dataset = input_fn(True, '', _BATCH_SIZE)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
spec = imagenet_main.imagenet_model_fn(
features, labels, mode, {
'resnet_size': 50,
......
......@@ -113,6 +113,31 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
return dataset
def get_synth_input_fn(height, width, num_channels, num_classes):
  """Returns an input function that returns a dataset with zeroes.

  This is useful in debugging input pipeline performance, as it removes all
  elements of file reading and image preprocessing.

  Args:
    height: Integer height that will be used to create a fake image tensor.
    width: Integer width that will be used to create a fake image tensor.
    num_channels: Integer depth that will be used to create a fake image tensor.
    num_classes: Number of classes that should be represented in the fake labels
      tensor.

  Returns:
    An input_fn that can be used in place of a real one to return a dataset
    that can be used for iteration.
  """
  def input_fn(is_training, data_dir, batch_size, *args, **kwargs):  # pylint: disable=unused-argument
    """Builds an endlessly repeating dataset holding one all-zero batch.

    Only `batch_size` is used. The remaining parameters, including any extra
    positional or keyword arguments, are accepted and ignored so that this
    function can be called exactly like a real input function.
    """
    images = tf.zeros((batch_size, height, width, num_channels), tf.float32)
    labels = tf.zeros((batch_size, num_classes), tf.int32)
    return tf.data.Dataset.from_tensors((images, labels)).repeat()

  return input_fn
################################################################################
# Functions building the ResNet model.
################################################################################
......@@ -673,5 +698,11 @@ class ResnetArgParser(argparse.ArgumentParser):
self.add_argument(
'--multi_gpu', action='store_true',
help='If set, run across all available GPUs. Note that this is '
'superseded by the --num_gpus flag.')
help='If set, run across all available GPUs.')
# Advanced args
self.add_argument(
'--use_synthetic_data', action='store_true',
help='If set, use fake data (zeroes) instead of a real dataset. '
'This mode is useful for performance debugging, as it removes '
'input processing steps, but will not learn anything.')
......@@ -64,7 +64,6 @@ def _random_crop_and_flip(image, crop_height, crop_width):
height, width = _get_h_w(image)
# Create a random bounding box.
#
# Use tf.random_uniform and not numpy.random.rand as doing the former would
# generate random numbers at graph eval time, unlike the latter which
# generates random numbers at graph definition time.
......@@ -79,6 +78,7 @@ def _random_crop_and_flip(image, crop_height, crop_width):
cropped = tf.image.random_flip_left_right(cropped)
return cropped
def _central_crop(image, crop_height, crop_width):
"""Performs central crops of the given image list.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment