Commit 65a967e9 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 410013894
parent 30e6e03f
...@@ -141,8 +141,8 @@ def get_distribution_strategy(distribution_strategy="mirrored", ...@@ -141,8 +141,8 @@ def get_distribution_strategy(distribution_strategy="mirrored",
distribution_strategy = distribution_strategy.lower() distribution_strategy = distribution_strategy.lower()
if distribution_strategy == "off": if distribution_strategy == "off":
if num_gpus > 1: if num_gpus > 1:
raise ValueError("When {} GPUs are specified, distribution_strategy " raise ValueError(f"When {num_gpus} GPUs are specified, "
"flag cannot be set to `off`.".format(num_gpus)) "distribution_strategy flag cannot be set to `off`.")
# Return the default distribution strategy. # Return the default distribution strategy.
return tf.distribute.get_strategy() return tf.distribute.get_strategy()
......
...@@ -12,24 +12,40 @@ ...@@ -12,24 +12,40 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Tests for distribution util functions.""" """Tests for distribution util functions."""
import tensorflow as tf import tensorflow as tf
from official.common import distribute_utils from official.common import distribute_utils
class GetDistributionStrategyTest(tf.test.TestCase): class DistributeUtilsTest(tf.test.TestCase):
"""Tests for get_distribution_strategy.""" """Tests for distribute util functions."""
def test_invalid_args(self):
with self.assertRaisesRegex(ValueError, '`num_gpus` can not be negative.'):
_ = distribute_utils.get_distribution_strategy(num_gpus=-1)
with self.assertRaisesRegex(ValueError,
'.*If you meant to pass the string .*'):
_ = distribute_utils.get_distribution_strategy(
distribution_strategy=False, num_gpus=0)
with self.assertRaisesRegex(ValueError, 'When 2 GPUs are specified.*'):
_ = distribute_utils.get_distribution_strategy(
distribution_strategy='off', num_gpus=2)
with self.assertRaisesRegex(ValueError,
'`OneDeviceStrategy` can not be used.*'):
_ = distribute_utils.get_distribution_strategy(
distribution_strategy='one_device', num_gpus=2)
def test_one_device_strategy_cpu(self): def test_one_device_strategy_cpu(self):
ds = distribute_utils.get_distribution_strategy(num_gpus=0) ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
self.assertEquals(ds.num_replicas_in_sync, 1) self.assertEquals(ds.num_replicas_in_sync, 1)
self.assertEquals(len(ds.extended.worker_devices), 1) self.assertEquals(len(ds.extended.worker_devices), 1)
self.assertIn('CPU', ds.extended.worker_devices[0]) self.assertIn('CPU', ds.extended.worker_devices[0])
def test_one_device_strategy_gpu(self): def test_one_device_strategy_gpu(self):
ds = distribute_utils.get_distribution_strategy(num_gpus=1) ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=1)
self.assertEquals(ds.num_replicas_in_sync, 1) self.assertEquals(ds.num_replicas_in_sync, 1)
self.assertEquals(len(ds.extended.worker_devices), 1) self.assertEquals(len(ds.extended.worker_devices), 1)
self.assertIn('GPU', ds.extended.worker_devices[0]) self.assertIn('GPU', ds.extended.worker_devices[0])
...@@ -41,6 +57,27 @@ class GetDistributionStrategyTest(tf.test.TestCase): ...@@ -41,6 +57,27 @@ class GetDistributionStrategyTest(tf.test.TestCase):
for device in ds.extended.worker_devices: for device in ds.extended.worker_devices:
self.assertIn('GPU', device) self.assertIn('GPU', device)
_ = distribute_utils.get_distribution_strategy(
distribution_strategy='mirrored',
num_gpus=2,
all_reduce_alg='nccl',
num_packs=2)
with self.assertRaisesRegex(
ValueError,
'When used with `mirrored`, valid values for all_reduce_alg are.*'):
_ = distribute_utils.get_distribution_strategy(
distribution_strategy='mirrored',
num_gpus=2,
all_reduce_alg='dummy',
num_packs=2)
def test_mwms(self):
distribute_utils.configure_cluster(worker_hosts=None, task_index=-1)
ds = distribute_utils.get_distribution_strategy(
'multi_worker_mirrored', all_reduce_alg='nccl')
self.assertIsInstance(
ds, tf.distribute.experimental.MultiWorkerMirroredStrategy)
def test_no_strategy(self): def test_no_strategy(self):
ds = distribute_utils.get_distribution_strategy('off') ds = distribute_utils.get_distribution_strategy('off')
self.assertIs(ds, tf.distribute.get_strategy()) self.assertIs(ds, tf.distribute.get_strategy())
...@@ -54,6 +91,12 @@ class GetDistributionStrategyTest(tf.test.TestCase): ...@@ -54,6 +91,12 @@ class GetDistributionStrategyTest(tf.test.TestCase):
ValueError, 'distribution_strategy must be a string but got: 1'): ValueError, 'distribution_strategy must be a string but got: 1'):
distribute_utils.get_distribution_strategy(1) distribute_utils.get_distribution_strategy(1)
def test_get_strategy_scope(self):
ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
with distribute_utils.get_strategy_scope(ds):
self.assertIs(tf.distribute.get_strategy(), ds)
with distribute_utils.get_strategy_scope(None):
self.assertIsNot(tf.distribute.get_strategy(), ds)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment