Commit dd376f53 authored by Igor Saprykin's avatar Igor Saprykin Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 263463300
parent c8660848
......@@ -80,10 +80,14 @@ class TransformerTaskTest(tf.test.TestCase):
self.assertTrue(os.path.exists(filepath))
def test_train_no_dist_strat(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
t = tm.TransformerTask(FLAGS)
t.train()
def test_train_static_batch(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
FLAGS.distribution_strategy = 'one_device'
FLAGS.static_batch = True
t = tm.TransformerTask(FLAGS)
......@@ -105,8 +109,8 @@ class TransformerTaskTest(tf.test.TestCase):
def test_train_2_gpu(self):
if context.num_gpus() < 2:
self.skipTest(
'{} GPUs are not available for this test. {} GPUs are available'.
format(2, context.num_gpus()))
'{} GPUs are not available for this test. {} GPUs are available'
.format(2, context.num_gpus()))
FLAGS.distribution_strategy = 'mirrored'
FLAGS.num_gpus = 2
FLAGS.param_set = 'base'
......@@ -117,8 +121,8 @@ class TransformerTaskTest(tf.test.TestCase):
def test_train_2_gpu_fp16(self):
if context.num_gpus() < 2:
self.skipTest(
'{} GPUs are not available for this test. {} GPUs are available'.
format(2, context.num_gpus()))
'{} GPUs are not available for this test. {} GPUs are available'
.format(2, context.num_gpus()))
FLAGS.distribution_strategy = 'mirrored'
FLAGS.num_gpus = 2
FLAGS.param_set = 'base'
......@@ -153,16 +157,22 @@ class TransformerTaskTest(tf.test.TestCase):
FLAGS(update_flags)
def test_predict(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
self._prepare_files_and_flags()
t = tm.TransformerTask(FLAGS)
t.predict()
def test_predict_fp16(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
self._prepare_files_and_flags('--dtype=fp16')
t = tm.TransformerTask(FLAGS)
t.predict()
def test_eval(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
self._prepare_files_and_flags()
t = tm.TransformerTask(FLAGS)
t.eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment