Commit 7b57b6d7 authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Fix transformer fp16 test

PiperOrigin-RevId: 263812316
parent 07c09ccc
...@@ -103,9 +103,9 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -103,9 +103,9 @@ class TransformerTaskTest(tf.test.TestCase):
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.train() t.train()
@unittest.skipIf(tf.test.is_built_with_cuda(), 'TODO(b/139497127): ' @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
'Test is skipped because tf.pad doesn\'t work with GPU.')
def test_train_fp16(self): def test_train_fp16(self):
FLAGS.distribution_strategy = 'one_device'
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.train() t.train()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment