Unverified Commit 05383c7b authored by Yuefeng Zhou's avatar Yuefeng Zhou Committed by GitHub
Browse files

Add pure synthetic data to keras resnet model. (#6174)

* Add pure synthetic data to keras resnet model.

* Add imports.

* Address comments.

* update comment

* Undo set up synthetic data for real data path.

* update comment

* Address comment

* Remove trailing whitespaces.

* s/make_data_set_iterator/make_dataset_iterator/
parent b2c9e3f5
...@@ -112,6 +112,7 @@ def run(flags_obj): ...@@ -112,6 +112,7 @@ def run(flags_obj):
tf.keras.backend.set_image_data_format(data_format) tf.keras.backend.set_image_data_format(data_format)
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data()
input_fn = keras_common.get_synth_input_fn( input_fn = keras_common.get_synth_input_fn(
height=cifar_main.HEIGHT, height=cifar_main.HEIGHT,
width=cifar_main.WIDTH, width=cifar_main.WIDTH,
...@@ -119,6 +120,7 @@ def run(flags_obj): ...@@ -119,6 +120,7 @@ def run(flags_obj):
num_classes=cifar_main.NUM_CLASSES, num_classes=cifar_main.NUM_CLASSES,
dtype=flags_core.get_tf_dtype(flags_obj)) dtype=flags_core.get_tf_dtype(flags_obj))
else: else:
distribution_utils.undo_set_up_synthetic_data()
input_fn = cifar_main.input_fn input_fn = cifar_main.input_fn
train_input_dataset = input_fn( train_input_dataset = input_fn(
......
...@@ -239,8 +239,13 @@ def get_synth_input_fn(height, width, num_channels, num_classes, ...@@ -239,8 +239,13 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
maxval=num_classes - 1, maxval=num_classes - 1,
dtype=tf.int32, dtype=tf.int32,
name='synthetic_labels') name='synthetic_labels')
# Cast to float32 for Keras model.
labels = tf.cast(labels, dtype=tf.float32)
data = tf.data.Dataset.from_tensors((inputs, labels)).repeat() data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
data = data.batch(batch_size)
# `drop_remainder` will make dataset produce outputs with known shapes.
data = data.batch(batch_size, drop_remainder=True)
data = data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) data = data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
return data return data
......
...@@ -103,6 +103,7 @@ def run(flags_obj): ...@@ -103,6 +103,7 @@ def run(flags_obj):
# pylint: disable=protected-access # pylint: disable=protected-access
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data()
input_fn = keras_common.get_synth_input_fn( input_fn = keras_common.get_synth_input_fn(
height=imagenet_main.DEFAULT_IMAGE_SIZE, height=imagenet_main.DEFAULT_IMAGE_SIZE,
width=imagenet_main.DEFAULT_IMAGE_SIZE, width=imagenet_main.DEFAULT_IMAGE_SIZE,
...@@ -110,6 +111,7 @@ def run(flags_obj): ...@@ -110,6 +111,7 @@ def run(flags_obj):
num_classes=imagenet_main.NUM_CLASSES, num_classes=imagenet_main.NUM_CLASSES,
dtype=flags_core.get_tf_dtype(flags_obj)) dtype=flags_core.get_tf_dtype(flags_obj))
else: else:
distribution_utils.undo_set_up_synthetic_data()
input_fn = imagenet_main.input_fn input_fn = imagenet_main.input_fn
train_input_dataset = input_fn(is_training=True, train_input_dataset = input_fn(is_training=True,
......
...@@ -18,6 +18,8 @@ from __future__ import absolute_import ...@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import random
import string
import tensorflow as tf import tensorflow as tf
...@@ -96,3 +98,69 @@ def per_device_batch_size(batch_size, num_gpus): ...@@ -96,3 +98,69 @@ def per_device_batch_size(batch_size, num_gpus):
).format(num_gpus, batch_size, batch_size - remainder) ).format(num_gpus, batch_size, batch_size - remainder)
raise ValueError(err) raise ValueError(err)
return int(batch_size / num_gpus) return int(batch_size / num_gpus)
# The `SyntheticDataset` is a temporary solution for generating synthetic data
# directly on devices. It is only useful for Keras with Distribution
# Strategies. We will have better support in `tf.data` or Distribution Strategy
# later.
class SyntheticDataset(object):
  """A dataset that generates synthetic data on each device.

  Captures a single element of a `tf.data.Dataset` into local variables so
  that every call to `get_next()` returns the same cached structure of
  tensors. Per the module comment, this is a temporary device-side
  synthetic-data mechanism for Keras with Distribution Strategies.
  """

  def __init__(self, dataset, split_by=1):
    """Caches one element of `dataset` in local variables.

    Args:
      dataset: A `tf.data.Dataset` from which a single element is taken.
      split_by: Number of equal slices to split each tensor into along
        axis 0; only the first slice is kept (used to turn a global batch
        into a per-replica batch — see `_monkey_patch_dataset_method`).
    """
    self._input_data = {}
    # dataset.take(1) doesn't have GPU kernel.
    with tf.device("device:CPU:0"):
      tensor = tf.data.experimental.get_single_element(dataset.take(1))
    flat_tensor = tf.nest.flatten(tensor)
    variable_data = []
    self._initializers = []
    for t in flat_tensor:
      # Keep only the first of `split_by` equal slices along the batch axis.
      rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
      # Variables require fully-defined shapes at creation time.
      assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
      v = tf.get_local_variable(self.random_name(), initializer=rebatched_t)  # pylint: disable=cell-var-from-loop
      variable_data.append(v)
      self._initializers.append(v.initializer)
    # Re-assemble the variables into the original nested structure.
    self._input_data = tf.nest.pack_sequence_as(tensor, variable_data)

  def get_next(self):
    # Same cached structure on every call — the data never changes.
    return self._input_data

  def initialize(self):
    """Returns the variable initializer op(s); a no-op when executing eagerly."""
    if tf.executing_eagerly():
      return tf.no_op()
    else:
      return self._initializers

  def random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
    """Returns a random `size`-character name drawn from `chars`."""
    return "".join(random.choice(chars) for _ in range(size))
def _monkey_patch_dataset_method(strategy):
  """Monkey-patch `strategy`'s `make_dataset_iterator` method."""

  def make_dataset_iterator(self, dataset):
    tf.logging.info("Using pure synthetic data.")
    with self.scope():
      # pylint: disable=protected-access
      if self.extended._global_batch_size:
        # Global batching: `SyntheticDataset` keeps 1/num_replicas of the
        # batch so each replica sees a per-replica sized slice.
        return SyntheticDataset(dataset, self.num_replicas_in_sync)
      return SyntheticDataset(dataset)

  # Stash the original method so `_undo_monkey_patch_dataset_method`
  # can restore it later.
  strategy.org_make_dataset_iterator = strategy.make_dataset_iterator
  strategy.make_dataset_iterator = make_dataset_iterator
def _undo_monkey_patch_dataset_method(strategy):
  """Restores `strategy.make_dataset_iterator` if it was patched before."""
  original = getattr(strategy, "org_make_dataset_iterator", None)
  if original is not None:
    strategy.make_dataset_iterator = original
def set_up_synthetic_data():
  """Patches every supported strategy class to emit pure synthetic data."""
  for strategy_class in (tf.distribute.MirroredStrategy,
                         tf.contrib.distribute.MirroredStrategy,
                         tf.contrib.distribute.OneDeviceStrategy):
    _monkey_patch_dataset_method(strategy_class)
def undo_set_up_synthetic_data():
  """Reverts the synthetic-data patch on every supported strategy class."""
  for strategy_class in (tf.distribute.MirroredStrategy,
                         tf.contrib.distribute.MirroredStrategy,
                         tf.contrib.distribute.OneDeviceStrategy):
    _undo_monkey_patch_dataset_method(strategy_class)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment