"vscode:/vscode.git/clone" did not exist on "3b31b72454a422aae30e16cfa5aa0e7b2b43d6b7"
Commit 1f2cebfa authored by Priya Gupta's avatar Priya Gupta Committed by A. Unique TensorFlower
Browse files

fix monkey patch for synthetic data for resnet keras model.

PiperOrigin-RevId: 263854996
parent 5a309240
class SyntheticDataset(object):
  """A dataset that generates synthetic data on each device.

  One element is read from the wrapped dataset, every tensor in it is reduced
  to the first of `split_by` equal slices along axis 0, and each slice is
  stored in a local variable. Every iterator accessor on this class returns
  the same `SyntheticIterator` built over those variables.
  """

  def __init__(self, dataset, split_by=1):
    """Caches the first element of `dataset` in local variables.

    Args:
      dataset: Dataset whose first element supplies the synthetic data.
      split_by: Number of equal slices each tensor is split into along
        axis 0; only the first slice is kept.
    """
    self._input_data = {}
    # dataset.take(1) doesn't have GPU kernel.
    # NOTE(review): only the element read is assumed to sit under the CPU
    # device scope -- confirm against the original indentation.
    with tf.device('device:CPU:0'):
      first_element = tf.data.experimental.get_single_element(dataset.take(1))
    cached_vars = []
    init_ops = []
    for component in tf.nest.flatten(first_element):
      shard = tf.split(component, num_or_size_splits=split_by, axis=0)[0]
      # The variable initializer needs a static shape.
      assert shard.shape.is_fully_defined(), shard.shape
      var = tf.compat.v1.get_local_variable(
          self._random_name(), initializer=shard)
      cached_vars.append(var)
      init_ops.append(var.initializer)
    # Restore the nesting structure of the element that was read.
    packed = tf.nest.pack_sequence_as(first_element, cached_vars)
    self._iterator = SyntheticIterator(packed, init_ops)

  def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
    """Returns a string of `size` characters picked at random from `chars`."""
    picked = [random.choice(chars) for _ in range(size)]
    return ''.join(picked)

  def __iter__(self):
    """Returns the shared synthetic iterator."""
    return self._iterator

  def make_one_shot_iterator(self):
    """Returns the shared synthetic iterator."""
    return self._iterator

  def make_initializable_iterator(self):
    """Returns the shared synthetic iterator."""
    return self._iterator
class SyntheticIterator(object):
  """An iterator that returns the same synthetic data on every step."""

  def __init__(self, input_data, initializers):
    """Stores the data to serve and the initializers to expose.

    Args:
      input_data: Value returned unchanged by every `get_next()` call.
      initializers: Value returned by `initialize()` when not executing
        eagerly.
    """
    self._input_data = input_data
    self._initializers = initializers

  def __iter__(self):
    """Returns self so the object works with `iter()` and `for` loops."""
    return self

  def get_next(self):
    """Returns the stored synthetic data."""
    return self._input_data

  def next(self):
    """Python 2 spelling of `__next__`."""
    return self.__next__()

  def __next__(self):
    """Returns the next element, mapping `OutOfRangeError` to `StopIteration`."""
    try:
      return self.get_next()
    except tf.errors.OutOfRangeError:
      raise StopIteration

  def initialize(self):
    """Returns a no-op when executing eagerly, else the stored initializers."""
    if tf.executing_eagerly():
      return tf.no_op()
    else:
      return self._initializers
def random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
return ''.join(random.choice(chars) for _ in range(size))
def _monkey_patch_dataset_method(strategy): def _monkey_patch_dataset_method(strategy):
"""Monkey-patch `strategy`'s `make_dataset_iterator` method.""" """Monkey-patch `strategy`'s `make_dataset_iterator` method."""
def make_dataset_iterator(self, dataset): def make_dataset(self, dataset):
tf.compat.v1.logging.info('Using pure synthetic data.') tf.compat.v1.logging.info('Using pure synthetic data.')
with self.scope(): with self.scope():
if self.extended._global_batch_size: # pylint: disable=protected-access if self.extended._global_batch_size: # pylint: disable=protected-access
...@@ -244,13 +270,21 @@ def _monkey_patch_dataset_method(strategy): ...@@ -244,13 +270,21 @@ def _monkey_patch_dataset_method(strategy):
else: else:
return SyntheticDataset(dataset) return SyntheticDataset(dataset)
strategy.org_make_dataset_iterator = strategy.make_dataset_iterator def make_iterator(self, dataset):
strategy.make_dataset_iterator = make_dataset_iterator dist_dataset = make_dataset(self, dataset)
return iter(dist_dataset)
strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
strategy.make_dataset_iterator = make_iterator
strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
strategy.experimental_distribute_dataset = make_dataset
def _undo_monkey_patch_dataset_method(strategy): def _undo_monkey_patch_dataset_method(strategy):
if hasattr(strategy, 'org_make_dataset_iterator'): if hasattr(strategy, 'orig_make_dataset_iterator'):
strategy.make_dataset_iterator = strategy.org_make_dataset_iterator strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
if hasattr(strategy, 'orig_distribute_dataset'):
strategy.make_dataset_iterator = strategy.orig_distribute_dataset
def set_up_synthetic_data(): def set_up_synthetic_data():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment