Commit ee584397 authored by Ayush Dubey's avatar Ayush Dubey Committed by A. Unique TensorFlower
Browse files

Extend synthetic data monkey patch to MultiWorkerMirroredStrategy.

PiperOrigin-RevId: 264244022
parent e4adc6f1
...@@ -287,10 +287,14 @@ def _undo_monkey_patch_dataset_method(strategy): ...@@ -287,10 +287,14 @@ def _undo_monkey_patch_dataset_method(strategy):
def set_up_synthetic_data(): def set_up_synthetic_data():
_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy) _monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
_monkey_patch_dataset_method(tf.distribute.MirroredStrategy) _monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
# TODO(tobyboyd): Remove when contrib.distribute is all in core. # TODO(tobyboyd): Remove when contrib.distribute is all in core.
if hasattr(tf, 'contrib'): if hasattr(tf, 'contrib'):
_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy) _monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy) _monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
_monkey_patch_dataset_method(
tf.contrib.distribute.CollectiveAllReduceStrategy)
else: else:
print('Contrib missing: Skip monkey patch tf.contrib.distribute.*') print('Contrib missing: Skip monkey patch tf.contrib.distribute.*')
...@@ -298,10 +302,14 @@ def set_up_synthetic_data(): ...@@ -298,10 +302,14 @@ def set_up_synthetic_data():
def undo_set_up_synthetic_data(): def undo_set_up_synthetic_data():
_undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy) _undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
_undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy) _undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_undo_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
# TODO(tobyboyd): Remove when contrib.distribute is all in core. # TODO(tobyboyd): Remove when contrib.distribute is all in core.
if hasattr(tf, 'contrib'): if hasattr(tf, 'contrib'):
_undo_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy) _undo_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
_undo_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy) _undo_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
_undo_monkey_patch_dataset_method(
tf.contrib.distribute.CollectiveAllReduceStrategy)
else: else:
print('Contrib missing: Skip remove monkey patch tf.contrib.distribute.*') print('Contrib missing: Skip remove monkey patch tf.contrib.distribute.*')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment