"tests/pytorch/vscode:/vscode.git/clone" did not exist on "fa5ff2fceb1855c1e162dfb3a2bdb54fbc43c265"
Commit d1ed379e authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 411069186
parent bd7732da
...@@ -14,10 +14,13 @@ ...@@ -14,10 +14,13 @@
"""Tests for distribution util functions.""" """Tests for distribution util functions."""
import sys
import tensorflow as tf import tensorflow as tf
from official.common import distribute_utils from official.common import distribute_utils
TPU_TEST = 'test_tpu' in sys.argv[0]
class DistributeUtilsTest(tf.test.TestCase): class DistributeUtilsTest(tf.test.TestCase):
"""Tests for distribute util functions.""" """Tests for distribute util functions."""
...@@ -51,6 +54,9 @@ class DistributeUtilsTest(tf.test.TestCase): ...@@ -51,6 +54,9 @@ class DistributeUtilsTest(tf.test.TestCase):
self.assertIn('GPU', ds.extended.worker_devices[0]) self.assertIn('GPU', ds.extended.worker_devices[0])
def test_mirrored_strategy(self): def test_mirrored_strategy(self):
# CPU only.
_ = distribute_utils.get_distribution_strategy(num_gpus=0)
# 5 GPUs.
ds = distribute_utils.get_distribution_strategy(num_gpus=5) ds = distribute_utils.get_distribution_strategy(num_gpus=5)
self.assertEquals(ds.num_replicas_in_sync, 5) self.assertEquals(ds.num_replicas_in_sync, 5)
self.assertEquals(len(ds.extended.worker_devices), 5) self.assertEquals(len(ds.extended.worker_devices), 5)
...@@ -78,10 +84,26 @@ class DistributeUtilsTest(tf.test.TestCase): ...@@ -78,10 +84,26 @@ class DistributeUtilsTest(tf.test.TestCase):
self.assertIsInstance( self.assertIsInstance(
ds, tf.distribute.experimental.MultiWorkerMirroredStrategy) ds, tf.distribute.experimental.MultiWorkerMirroredStrategy)
with self.assertRaisesRegex(
ValueError,
'When used with `multi_worker_mirrored`, valid values.*'):
_ = distribute_utils.get_distribution_strategy(
'multi_worker_mirrored', all_reduce_alg='dummy')
def test_no_strategy(self): def test_no_strategy(self):
ds = distribute_utils.get_distribution_strategy('off') ds = distribute_utils.get_distribution_strategy('off')
self.assertIs(ds, tf.distribute.get_strategy()) self.assertIs(ds, tf.distribute.get_strategy())
def test_tpu_strategy(self):
if not TPU_TEST:
self.skipTest('Only Cloud TPU VM instances can have local TPUs.')
with self.assertRaises(ValueError):
_ = distribute_utils.get_distribution_strategy('tpu')
ds = distribute_utils.get_distribution_strategy('tpu', tpu_address='local')
self.assertIsInstance(
ds, tf.distribute.TPUStrategy)
def test_invalid_strategy(self): def test_invalid_strategy(self):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
ValueError, ValueError,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment