Unverified commit a1ee97e6 authored by Toby Boyd, committed by GitHub

V2 contrib tweaks (#6184)

* Remove contrib thread pool.

* Remove commented out contrib import.

* Fix lint issues.

* Move tf.data.options higher. Tweak line breaks.

* Do not monkey patch on or off if dist_strat is off.

* Do not monkey patch if no_dist_strat.

* Fix file permissions.

* Fix file permissions.

* Revert change to main. Add hasattr(tf, 'contrib') to utils.

* Use tf.compat.v1.logging.

* Use tf.compat.v1.get_local_variable.
parent 7e056690
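
A minimal sketch (not part of this commit) of the two compatibility patterns the change leans on: v1 symbols reached through tf.compat.v1 so the code runs on both 1.x and 2.x builds, and a hasattr(tf, 'contrib') guard for builds where contrib has been removed.

import tensorflow as tf

# compat.v1 aliases work whether or not the v1 symbols are still exposed
# at the top level of the tf namespace.
tf.compat.v1.logging.info('Logging through the compat.v1 alias.')

# Guard contrib usage: 1.x builds still ship tf.contrib.distribute,
# 2.x builds do not.
if hasattr(tf, 'contrib'):
  strategy_cls = tf.contrib.distribute.MirroredStrategy
else:
  strategy_cls = tf.distribute.MirroredStrategy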
@@ -48,18 +48,18 @@ def get_distribution_strategy(num_gpus,
     if turn_off_distribution_strategy:
       return None
     else:
-      return tf.contrib.distribute.OneDeviceStrategy("device:CPU:0")
+      return tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
   elif num_gpus == 1:
     if turn_off_distribution_strategy:
       return None
     else:
-      return tf.contrib.distribute.OneDeviceStrategy("device:GPU:0")
+      return tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
   elif turn_off_distribution_strategy:
-    raise ValueError("When {} GPUs are specified, "
-                     "turn_off_distribution_strategy flag cannot be set to"
-                     "True.".format(num_gpus))
+    raise ValueError('When {} GPUs are specified, '
+                     'turn_off_distribution_strategy flag cannot be set to '
+                     'True.'.format(num_gpus))
   else:  # num_gpus > 1 and not turn_off_distribution_strategy
-    devices = ["device:GPU:%d" % i for i in range(num_gpus)]
+    devices = ['device:GPU:%d' % i for i in range(num_gpus)]
     if all_reduce_alg:
       return tf.distribute.MirroredStrategy(
           devices=devices,
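
A hedged usage sketch of get_distribution_strategy as seen in the hunk above. The module path and the keyword names beyond num_gpus are assumptions based on the truncated signature shown here, not something this diff confirms.

import tensorflow as tf
from official.utils.misc import distribution_utils  # assumed module path

strategy = distribution_utils.get_distribution_strategy(
    num_gpus=2, turn_off_distribution_strategy=False)
if strategy:
  with strategy.scope():
    # Build the model inside the strategy scope so variables are mirrored.
    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    model.compile(optimizer='sgd', loss='mse')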
@@ -92,13 +92,14 @@ def per_device_batch_size(batch_size, num_gpus):
   remainder = batch_size % num_gpus
   if remainder:
-    err = ("When running with multiple GPUs, batch size "
-           "must be a multiple of the number of available GPUs. Found {} "
-           "GPUs with a batch size of {}; try --batch_size={} instead."
+    err = ('When running with multiple GPUs, batch size '
+           'must be a multiple of the number of available GPUs. Found {} '
+           'GPUs with a batch size of {}; try --batch_size={} instead.'
           ).format(num_gpus, batch_size, batch_size - remainder)
     raise ValueError(err)
   return int(batch_size / num_gpus)
 
 
 # The `SyntheticDataset` is a temporary solution for generating synthetic data
 # directly on devices. It is only useful for Keras with Distribution
 # Strategies. We will have better support in `tf.data` or Distribution Strategy
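
A standalone sketch of the divisibility check above, with a worked example. The num_gpus <= 1 early return is an assumption, since that part of the function falls outside this hunk.

def _per_device_batch_size(batch_size, num_gpus):
  if num_gpus <= 1:  # assumed early return; not shown in the hunk above
    return batch_size
  remainder = batch_size % num_gpus
  if remainder:
    raise ValueError(
        'When running with multiple GPUs, batch size must be a multiple of '
        'the number of available GPUs. Found {} GPUs with a batch size of '
        '{}; try --batch_size={} instead.'.format(
            num_gpus, batch_size, batch_size - remainder))
  return int(batch_size / num_gpus)

print(_per_device_batch_size(32, 4))  # 8 examples per replica
# _per_device_batch_size(32, 5) raises ValueError suggesting --batch_size=30,
# i.e. the nearest smaller multiple (32 - 32 % 5 == 30).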
@@ -109,7 +110,7 @@ class SyntheticDataset(object):
   def __init__(self, dataset, split_by=1):
     self._input_data = {}
     # dataset.take(1) doesn't have GPU kernel.
-    with tf.device("device:CPU:0"):
+    with tf.device('device:CPU:0'):
       tensor = tf.data.experimental.get_single_element(dataset.take(1))
     flat_tensor = tf.nest.flatten(tensor)
     variable_data = []
@@ -117,7 +118,8 @@ class SyntheticDataset(object):
     for t in flat_tensor:
       rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
       assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
-      v = tf.get_local_variable(self.random_name(), initializer=rebatched_t)  # pylint: disable=cell-var-from-loop
+      v = tf.compat.v1.get_local_variable(self.random_name(),
+                                          initializer=rebatched_t)
       variable_data.append(v)
       self._initializers.append(v.initializer)
     self._input_data = tf.nest.pack_sequence_as(tensor, variable_data)
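
Roughly what the constructor above does, as an eager-mode sketch with tf.Variable standing in for tf.compat.v1.get_local_variable (the real code uses randomly named local variables so their initializers can be re-run per replica). The toy dataset here is an assumption for illustration only.

import tensorflow as tf

dataset = tf.data.Dataset.from_tensors(
    ({'x': tf.zeros([8, 4])}, tf.zeros([8, 1])))

# Pull a single concrete batch onto the host ...
with tf.device('device:CPU:0'):
  tensor = tf.data.experimental.get_single_element(dataset.take(1))

# ... store each leaf tensor in a variable, then rebuild the original
# nested structure so the cached batch looks like real input data.
flat = tf.nest.flatten(tensor)
variables = [tf.Variable(t) for t in flat]
synthetic_batch = tf.nest.pack_sequence_as(tensor, variables)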
@@ -132,13 +134,13 @@ class SyntheticDataset(object):
     return self._initializers
 
   def random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
-    return "".join(random.choice(chars) for _ in range(size))
+    return ''.join(random.choice(chars) for _ in range(size))
 
 
 def _monkey_patch_dataset_method(strategy):
   """Monkey-patch `strategy`'s `make_dataset_iterator` method."""
   def make_dataset_iterator(self, dataset):
-    tf.logging.info("Using pure synthetic data.")
+    tf.compat.v1.logging.info('Using pure synthetic data.')
     with self.scope():
       if self.extended._global_batch_size:  # pylint: disable=protected-access
         return SyntheticDataset(dataset, self.num_replicas_in_sync)
@@ -150,17 +152,25 @@ def _monkey_patch_dataset_method(strategy):
 
 
 def _undo_monkey_patch_dataset_method(strategy):
-  if hasattr(strategy, "org_make_dataset_iterator"):
+  if hasattr(strategy, 'org_make_dataset_iterator'):
     strategy.make_dataset_iterator = strategy.org_make_dataset_iterator
 
 
 def set_up_synthetic_data():
   _monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
-  _monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
-  _monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
+  # TODO(tobyboyd): Remove when contrib.distribute is all in core.
+  if hasattr(tf, 'contrib'):
+    _monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
+    _monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
+  else:
+    print('Contrib missing: Skip monkey patch tf.contrib.distribute.*')
 
 
 def undo_set_up_synthetic_data():
   _undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
-  _undo_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
-  _undo_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
+  # TODO(tobyboyd): Remove when contrib.distribute is all in core.
+  if hasattr(tf, 'contrib'):
+    _undo_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
+    _undo_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
+  else:
+    print('Contrib missing: Skip remove monkey patch tf.contrib.distribute.*')
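
A hedged end-to-end sketch of how these helpers are meant to be used (module path assumed, as above): patch the strategy classes before any input pipeline is built, then restore them so later code sees the real make_dataset_iterator.

from official.utils.misc import distribution_utils  # assumed module path

distribution_utils.set_up_synthetic_data()
try:
  strategy = distribution_utils.get_distribution_strategy(
      num_gpus=1, turn_off_distribution_strategy=False)
  # make_dataset_iterator calls on the patched strategy classes now return
  # a SyntheticDataset built from a single cached batch.
  # ... build and benchmark the model here ...
finally:
  distribution_utils.undo_set_up_synthetic_data()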