Commit 296d0d3f authored by guptapriya's avatar guptapriya
Browse files

Unskip tests with 1.x

parent 3a796b5a
......@@ -79,25 +79,23 @@ class TransformerTaskTest(tf.test.TestCase):
def _assert_exists(self, filepath):
self.assertTrue(os.path.exists(filepath))
def test_train(self):
def test_train_no_dist_strat(self):
t = tm.TransformerTask(FLAGS)
t.train()
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_train_static_batch(self):
FLAGS.distribution_strategy = 'one_device'
FLAGS.static_batch = True
t = tm.TransformerTask(FLAGS)
t.train()
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_train_1_gpu_with_dist_strat(self):
FLAGS.distribution_strategy = 'one_device'
t = tm.TransformerTask(FLAGS)
t.train()
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_train_2_gpu(self):
if context.num_gpus() < 2:
self.skipTest(
......@@ -110,7 +108,6 @@ class TransformerTaskTest(tf.test.TestCase):
t.train()
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_train_2_gpu_fp16(self):
if context.num_gpus() < 2:
self.skipTest(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment