Unverified Commit a1ee97e6 authored by Toby Boyd, committed by GitHub

V2 contrib tweaks (#6184)

* Remove contrib thread pool.

* Remove commented out contrib import.

* Fix lint issues.

* Move tf.data.Options higher. Tweak line breaks.

* Do not monkey patch (or unpatch) if dist_strat is off.

* Do not monkey patch if no_dist_strat.

* Fix file permissions.

* Fix file permissions.

* Revert change to main. Add hasattr(tf, 'contrib') to utils.

* Use tf.compat.v1.logging.

* Use tf.compat.v1.get_local_variable.
parent 7e056690
@@ -48,18 +48,18 @@ def get_distribution_strategy(num_gpus,
     if turn_off_distribution_strategy:
       return None
     else:
-      return tf.contrib.distribute.OneDeviceStrategy("device:CPU:0")
+      return tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
   elif num_gpus == 1:
     if turn_off_distribution_strategy:
       return None
     else:
-      return tf.contrib.distribute.OneDeviceStrategy("device:GPU:0")
+      return tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
   elif turn_off_distribution_strategy:
-    raise ValueError("When {} GPUs are specified, "
-                     "turn_off_distribution_strategy flag cannot be set to"
-                     "True.".format(num_gpus))
+    raise ValueError('When {} GPUs are specified, '
+                     'turn_off_distribution_strategy flag cannot be set to'
+                     'True.'.format(num_gpus))
   else:  # num_gpus > 1 and not turn_off_distribution_strategy
-    devices = ["device:GPU:%d" % i for i in range(num_gpus)]
+    devices = ['device:GPU:%d' % i for i in range(num_gpus)]
     if all_reduce_alg:
       return tf.distribute.MirroredStrategy(
           devices=devices,
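For context, a rough sketch of how a caller might consume the strategy selection above. The module path and flag values are illustrative assumptions, not part of this diff.

```python
# Illustrative only: assumes this file is official/utils/misc/distribution_utils.py.
from official.utils.misc import distribution_utils

# 0 or 1 GPU: a OneDeviceStrategy pinned to CPU or GPU, or None when
# turn_off_distribution_strategy=True.
strategy = distribution_utils.get_distribution_strategy(num_gpus=1)

# More than one GPU: a MirroredStrategy over the visible GPUs.
multi_gpu_strategy = distribution_utils.get_distribution_strategy(num_gpus=2)

# Asking to turn the strategy off with multiple GPUs is rejected.
try:
  distribution_utils.get_distribution_strategy(
      num_gpus=2, turn_off_distribution_strategy=True)
except ValueError as err:
  print(err)
```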
@@ -92,13 +92,14 @@ def per_device_batch_size(batch_size, num_gpus):
   remainder = batch_size % num_gpus
   if remainder:
-    err = ("When running with multiple GPUs, batch size "
-           "must be a multiple of the number of available GPUs. Found {} "
-           "GPUs with a batch size of {}; try --batch_size={} instead."
+    err = ('When running with multiple GPUs, batch size '
+           'must be a multiple of the number of available GPUs. Found {} '
+           'GPUs with a batch size of {}; try --batch_size={} instead.'
            ).format(num_gpus, batch_size, batch_size - remainder)
     raise ValueError(err)
   return int(batch_size / num_gpus)


 # The `SyntheticDataset` is a temporary solution for generating synthetic data
 # directly on devices. It is only useful for Keras with Distribution
 # Strategies. We will have better support in `tf.data` or Distribution Strategy
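A quick worked example of the batch-size check above (illustrative; assumes the same module path as before):

```python
from official.utils.misc import distribution_utils

# 256 examples split across 4 GPUs: each replica processes 64.
print(distribution_utils.per_device_batch_size(256, num_gpus=4))  # 64

# 250 is not divisible by 4 (remainder 2), so a ValueError is raised and the
# message suggests --batch_size=248 (250 - 2) instead.
try:
  distribution_utils.per_device_batch_size(250, num_gpus=4)
except ValueError as err:
  print(err)
```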
@@ -109,7 +110,7 @@ class SyntheticDataset(object):
   def __init__(self, dataset, split_by=1):
     self._input_data = {}
     # dataset.take(1) doesn't have GPU kernel.
-    with tf.device("device:CPU:0"):
+    with tf.device('device:CPU:0'):
       tensor = tf.data.experimental.get_single_element(dataset.take(1))
     flat_tensor = tf.nest.flatten(tensor)
     variable_data = []
@@ -117,7 +118,8 @@ class SyntheticDataset(object):
     for t in flat_tensor:
       rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
       assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
-      v = tf.get_local_variable(self.random_name(), initializer=rebatched_t)  # pylint: disable=cell-var-from-loop
+      v = tf.compat.v1.get_local_variable(self.random_name(),
+                                          initializer=rebatched_t)
       variable_data.append(v)
       self._initializers.append(v.initializer)
     self._input_data = tf.nest.pack_sequence_as(tensor, variable_data)
@@ -132,13 +134,13 @@ class SyntheticDataset(object):
     return self._initializers

   def random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
-    return "".join(random.choice(chars) for _ in range(size))
+    return ''.join(random.choice(chars) for _ in range(size))


 def _monkey_patch_dataset_method(strategy):
   """Monkey-patch `strategy`'s `make_dataset_iterator` method."""
   def make_dataset_iterator(self, dataset):
-    tf.logging.info("Using pure synthetic data.")
+    tf.compat.v1.logging.info('Using pure synthetic data.')
     with self.scope():
       if self.extended._global_batch_size:  # pylint: disable=protected-access
         return SyntheticDataset(dataset, self.num_replicas_in_sync)
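The tf.logging -> tf.compat.v1.logging and tf.get_local_variable -> tf.compat.v1.get_local_variable moves keep this module importable against TF 2.0 builds, where the top-level 1.x symbols are removed but still reachable under tf.compat.v1. A minimal sketch of the pattern (illustrative, not part of the diff):

```python
import tensorflow as tf

# Works on 1.x releases that ship tf.compat.v1 as well as TF 2.0 previews:
# the 1.x logging API stays reachable even once tf.logging itself is gone.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
tf.compat.v1.logging.info('Using pure synthetic data.')
```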
@@ -150,17 +152,25 @@ def _monkey_patch_dataset_method(strategy):


 def _undo_monkey_patch_dataset_method(strategy):
-  if hasattr(strategy, "org_make_dataset_iterator"):
+  if hasattr(strategy, 'org_make_dataset_iterator'):
     strategy.make_dataset_iterator = strategy.org_make_dataset_iterator


 def set_up_synthetic_data():
   _monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
-  _monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
-  _monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
+  # TODO(tobyboyd): Remove when contrib.distribute is all in core.
+  if hasattr(tf, 'contrib'):
+    _monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
+    _monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
+  else:
+    print('Contrib missing: Skip monkey patch tf.contrib.distribute.*')


 def undo_set_up_synthetic_data():
   _undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
-  _undo_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
-  _undo_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
+  # TODO(tobyboyd): Remove when contrib.distribute is all in core.
+  if hasattr(tf, 'contrib'):
+    _undo_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
+    _undo_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
+  else:
+    print('Contrib missing: Skip remove monkey patch tf.contrib.distribute.*')
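For reference, a hedged sketch of how a benchmark driver might toggle the synthetic-data patch around a training run; the flag name and the train_fn callable are illustrative assumptions, not part of this commit.

```python
from official.utils.misc import distribution_utils

def run_with_optional_synthetic_data(use_synthetic_data, train_fn):
  """Illustrative driver: patch the strategy classes, train, then restore."""
  if use_synthetic_data:
    # Swaps make_dataset_iterator on the core (and, when present, contrib)
    # strategy classes so input is served from on-device variables.
    distribution_utils.set_up_synthetic_data()
  try:
    return train_fn()
  finally:
    # Safe even if nothing was patched: the undo helper checks for the saved
    # org_make_dataset_iterator attribute before restoring it.
    distribution_utils.undo_set_up_synthetic_data()
```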