Commit 4864e795 authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Fix issue where distribution_strategy="off" did not work.

Also improve error message when distribution=off is specified without properly quoting "off"

PiperOrigin-RevId: 359395329
parent 1fc6b4f5
...@@ -127,6 +127,15 @@ def get_distribution_strategy(distribution_strategy="mirrored", ...@@ -127,6 +127,15 @@ def get_distribution_strategy(distribution_strategy="mirrored",
if num_gpus < 0: if num_gpus < 0:
raise ValueError("`num_gpus` can not be negative.") raise ValueError("`num_gpus` can not be negative.")
if not isinstance(distribution_strategy, str):
msg = ("distribution_strategy must be a string but got: %s." %
(distribution_strategy,))
if distribution_strategy == False: # pylint: disable=singleton-comparison,g-explicit-bool-comparison
msg += (" If you meant to pass the string 'off', make sure you add "
"quotes around 'off' so that yaml interprets it as a string "
"instead of a bool.")
raise ValueError(msg)
distribution_strategy = distribution_strategy.lower() distribution_strategy = distribution_strategy.lower()
if distribution_strategy == "off": if distribution_strategy == "off":
if num_gpus > 1: if num_gpus > 1:
......
...@@ -41,6 +41,19 @@ class GetDistributionStrategyTest(tf.test.TestCase): ...@@ -41,6 +41,19 @@ class GetDistributionStrategyTest(tf.test.TestCase):
for device in ds.extended.worker_devices: for device in ds.extended.worker_devices:
self.assertIn('GPU', device) self.assertIn('GPU', device)
def test_no_strategy(self):
ds = distribute_utils.get_distribution_strategy('off')
self.assertIsNone(ds)
def test_invalid_strategy(self):
with self.assertRaisesRegexp(
ValueError,
'distribution_strategy must be a string but got: False. If'):
distribute_utils.get_distribution_strategy(False)
with self.assertRaisesRegexp(
ValueError, 'distribution_strategy must be a string but got: 1'):
distribute_utils.get_distribution_strategy(1)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment