Commit 6dc4ae73 authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Return default strategy from get_distribution_strategy when given "off".

Before, it returned None. But almost every use of get_distribution_strategy() assumes an actual strategy is returned and crashes when None is returned. Returning the default strategy fixes these issues and is equivalent to using no strategy, as the default strategy is always in effect when no other strategy is used.

PiperOrigin-RevId: 380951055
parent 0327186b
......@@ -102,8 +102,10 @@ def get_distribution_strategy(distribution_strategy="mirrored",
distribution_strategy: a string specifying which distribution strategy to
use. Accepted values are "off", "one_device", "mirrored",
"parameter_server", "multi_worker_mirrored", and "tpu" -- case
insensitive. "off" means not to use Distribution Strategy; "tpu" means to
use TPUStrategy using `tpu_address`.
insensitive. "tpu" means to use TPUStrategy using `tpu_address`.
"off" means to use the default strategy which is obtained from
tf.distribute.get_strategy (for details on the default strategy, see
https://www.tensorflow.org/guide/distributed_training#default_strategy).
num_gpus: Number of GPUs to run this model.
all_reduce_alg: Optional. Specifies which algorithm to use when performing
all-reduce. For `MirroredStrategy`, valid values are "nccl" and
......@@ -141,7 +143,8 @@ def get_distribution_strategy(distribution_strategy="mirrored",
if num_gpus > 1:
raise ValueError("When {} GPUs are specified, distribution_strategy "
"flag cannot be set to `off`.".format(num_gpus))
return None
# Return the default distribution strategy.
return tf.distribute.get_strategy()
if distribution_strategy == "tpu":
# When tpu_address is an empty string, we communicate with local TPUs.
......
......@@ -43,7 +43,7 @@ class GetDistributionStrategyTest(tf.test.TestCase):
def test_no_strategy(self):
ds = distribute_utils.get_distribution_strategy('off')
self.assertIsNone(ds)
self.assertIs(ds, tf.distribute.get_strategy())
def test_invalid_strategy(self):
with self.assertRaisesRegexp(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment