"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "2e77bb3eed670ce133984b110b740294daca44ee"
Commit 296d0d3f authored by guptapriya's avatar guptapriya
Browse files

Unskip tests with 1.x

parent 3a796b5a
...@@ -79,25 +79,23 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -79,25 +79,23 @@ class TransformerTaskTest(tf.test.TestCase):
def _assert_exists(self, filepath): def _assert_exists(self, filepath):
self.assertTrue(os.path.exists(filepath)) self.assertTrue(os.path.exists(filepath))
def test_train(self): def test_train_no_dist_strat(self):
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.train() t.train()
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_train_static_batch(self): def test_train_static_batch(self):
FLAGS.distribution_strategy = 'one_device'
FLAGS.static_batch = True FLAGS.static_batch = True
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.train() t.train()
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_train_1_gpu_with_dist_strat(self): def test_train_1_gpu_with_dist_strat(self):
FLAGS.distribution_strategy = 'one_device' FLAGS.distribution_strategy = 'one_device'
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.train() t.train()
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_train_2_gpu(self): def test_train_2_gpu(self):
if context.num_gpus() < 2: if context.num_gpus() < 2:
self.skipTest( self.skipTest(
...@@ -110,7 +108,6 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -110,7 +108,6 @@ class TransformerTaskTest(tf.test.TestCase):
t.train() t.train()
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_train_2_gpu_fp16(self): def test_train_2_gpu_fp16(self):
if context.num_gpus() < 2: if context.num_gpus() < 2:
self.skipTest( self.skipTest(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment