Commit 1f2cebfa authored by Priya Gupta's avatar Priya Gupta Committed by A. Unique TensorFlower
Browse files

fix monkey patch for synthetic data for resnet keras model.

PiperOrigin-RevId: 263854996
parent 5a309240
...@@ -205,38 +205,64 @@ class SyntheticDataset(object): ...@@ -205,38 +205,64 @@ class SyntheticDataset(object):
"""A dataset that generates synthetic data on each device.""" """A dataset that generates synthetic data on each device."""
def __init__(self, dataset, split_by=1): def __init__(self, dataset, split_by=1):
self._input_data = {}
# dataset.take(1) doesn't have GPU kernel. # dataset.take(1) doesn't have GPU kernel.
with tf.device('device:CPU:0'): with tf.device('device:CPU:0'):
tensor = tf.data.experimental.get_single_element(dataset.take(1)) tensor = tf.data.experimental.get_single_element(dataset.take(1))
flat_tensor = tf.nest.flatten(tensor) flat_tensor = tf.nest.flatten(tensor)
variable_data = [] variable_data = []
self._initializers = [] initializers = []
for t in flat_tensor: for t in flat_tensor:
rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0] rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
v = tf.compat.v1.get_local_variable(self.random_name(), v = tf.compat.v1.get_local_variable(self._random_name(),
initializer=rebatched_t) initializer=rebatched_t)
variable_data.append(v) variable_data.append(v)
self._initializers.append(v.initializer) initializers.append(v.initializer)
self._input_data = tf.nest.pack_sequence_as(tensor, variable_data) input_data = tf.nest.pack_sequence_as(tensor, variable_data)
self._iterator = SyntheticIterator(input_data, initializers)
def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
return ''.join(random.choice(chars) for _ in range(size))
def __iter__(self):
return self._iterator
def make_one_shot_iterator(self):
return self._iterator
def make_initializable_iterator(self):
return self._iterator
class SyntheticIterator(object):
"""A dataset that generates synthetic data on each device."""
def __init__(self, input_data, initializers):
self._input_data = input_data
self._initializers = initializers
def get_next(self): def get_next(self):
return self._input_data return self._input_data
def next(self):
return self.__next__()
def __next__(self):
try:
return self.get_next()
except tf.errors.OutOfRangeError:
raise StopIteration
def initialize(self): def initialize(self):
if tf.executing_eagerly(): if tf.executing_eagerly():
return tf.no_op() return tf.no_op()
else: else:
return self._initializers return self._initializers
def random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
return ''.join(random.choice(chars) for _ in range(size))
def _monkey_patch_dataset_method(strategy): def _monkey_patch_dataset_method(strategy):
"""Monkey-patch `strategy`'s `make_dataset_iterator` method.""" """Monkey-patch `strategy`'s `make_dataset_iterator` method."""
def make_dataset_iterator(self, dataset): def make_dataset(self, dataset):
tf.compat.v1.logging.info('Using pure synthetic data.') tf.compat.v1.logging.info('Using pure synthetic data.')
with self.scope(): with self.scope():
if self.extended._global_batch_size: # pylint: disable=protected-access if self.extended._global_batch_size: # pylint: disable=protected-access
...@@ -244,13 +270,21 @@ def _monkey_patch_dataset_method(strategy): ...@@ -244,13 +270,21 @@ def _monkey_patch_dataset_method(strategy):
else: else:
return SyntheticDataset(dataset) return SyntheticDataset(dataset)
strategy.org_make_dataset_iterator = strategy.make_dataset_iterator def make_iterator(self, dataset):
strategy.make_dataset_iterator = make_dataset_iterator dist_dataset = make_dataset(self, dataset)
return iter(dist_dataset)
strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
strategy.make_dataset_iterator = make_iterator
strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
strategy.experimental_distribute_dataset = make_dataset
def _undo_monkey_patch_dataset_method(strategy): def _undo_monkey_patch_dataset_method(strategy):
if hasattr(strategy, 'org_make_dataset_iterator'): if hasattr(strategy, 'orig_make_dataset_iterator'):
strategy.make_dataset_iterator = strategy.org_make_dataset_iterator strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
if hasattr(strategy, 'orig_distribute_dataset'):
strategy.make_dataset_iterator = strategy.orig_distribute_dataset
def set_up_synthetic_data(): def set_up_synthetic_data():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment