Commit dd376f53 authored by Igor Saprykin's avatar Igor Saprykin Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 263463300
parent c8660848
...@@ -80,10 +80,14 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -80,10 +80,14 @@ class TransformerTaskTest(tf.test.TestCase):
self.assertTrue(os.path.exists(filepath)) self.assertTrue(os.path.exists(filepath))
def test_train_no_dist_strat(self): def test_train_no_dist_strat(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.train() t.train()
def test_train_static_batch(self): def test_train_static_batch(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
FLAGS.distribution_strategy = 'one_device' FLAGS.distribution_strategy = 'one_device'
FLAGS.static_batch = True FLAGS.static_batch = True
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
...@@ -105,8 +109,8 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -105,8 +109,8 @@ class TransformerTaskTest(tf.test.TestCase):
def test_train_2_gpu(self): def test_train_2_gpu(self):
if context.num_gpus() < 2: if context.num_gpus() < 2:
self.skipTest( self.skipTest(
'{} GPUs are not available for this test. {} GPUs are available'. '{} GPUs are not available for this test. {} GPUs are available'
format(2, context.num_gpus())) .format(2, context.num_gpus()))
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.num_gpus = 2 FLAGS.num_gpus = 2
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
...@@ -117,8 +121,8 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -117,8 +121,8 @@ class TransformerTaskTest(tf.test.TestCase):
def test_train_2_gpu_fp16(self): def test_train_2_gpu_fp16(self):
if context.num_gpus() < 2: if context.num_gpus() < 2:
self.skipTest( self.skipTest(
'{} GPUs are not available for this test. {} GPUs are available'. '{} GPUs are not available for this test. {} GPUs are available'
format(2, context.num_gpus())) .format(2, context.num_gpus()))
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.num_gpus = 2 FLAGS.num_gpus = 2
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
...@@ -153,16 +157,22 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -153,16 +157,22 @@ class TransformerTaskTest(tf.test.TestCase):
FLAGS(update_flags) FLAGS(update_flags)
def test_predict(self): def test_predict(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
self._prepare_files_and_flags() self._prepare_files_and_flags()
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.predict() t.predict()
def test_predict_fp16(self): def test_predict_fp16(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
self._prepare_files_and_flags('--dtype=fp16') self._prepare_files_and_flags('--dtype=fp16')
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.predict() t.predict()
def test_eval(self): def test_eval(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
self._prepare_files_and_flags() self._prepare_files_and_flags()
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.eval() t.eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment