Commit d1ed379e authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 411069186
parent bd7732da
......@@ -14,10 +14,13 @@
"""Tests for distribution util functions."""
import sys
import tensorflow as tf
from official.common import distribute_utils
TPU_TEST = 'test_tpu' in sys.argv[0]
class DistributeUtilsTest(tf.test.TestCase):
"""Tests for distribute util functions."""
......@@ -51,6 +54,9 @@ class DistributeUtilsTest(tf.test.TestCase):
self.assertIn('GPU', ds.extended.worker_devices[0])
def test_mirrored_strategy(self):
# CPU only.
_ = distribute_utils.get_distribution_strategy(num_gpus=0)
# 5 GPUs.
ds = distribute_utils.get_distribution_strategy(num_gpus=5)
self.assertEquals(ds.num_replicas_in_sync, 5)
self.assertEquals(len(ds.extended.worker_devices), 5)
......@@ -78,10 +84,26 @@ class DistributeUtilsTest(tf.test.TestCase):
self.assertIsInstance(
ds, tf.distribute.experimental.MultiWorkerMirroredStrategy)
with self.assertRaisesRegex(
ValueError,
'When used with `multi_worker_mirrored`, valid values.*'):
_ = distribute_utils.get_distribution_strategy(
'multi_worker_mirrored', all_reduce_alg='dummy')
def test_no_strategy(self):
ds = distribute_utils.get_distribution_strategy('off')
self.assertIs(ds, tf.distribute.get_strategy())
def test_tpu_strategy(self):
if not TPU_TEST:
self.skipTest('Only Cloud TPU VM instances can have local TPUs.')
with self.assertRaises(ValueError):
_ = distribute_utils.get_distribution_strategy('tpu')
ds = distribute_utils.get_distribution_strategy('tpu', tpu_address='local')
self.assertIsInstance(
ds, tf.distribute.TPUStrategy)
def test_invalid_strategy(self):
with self.assertRaisesRegexp(
ValueError,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment