Commit 2e2584fc authored by lcskrishna's avatar lcskrishna
Browse files

skip tests that are failing after bfp16

parent 4ac8ecb9
......@@ -11,6 +11,8 @@ import torch.nn.functional as F
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_BFLOAT16, ALWAYS_FLOAT, MATCH_INPUT
from apex.testing.common_utils import skipIfRocm
def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
for fn, typ in it.product(fns, expected.keys()):
x = torch.randn(input_shape, dtype=typ).requires_grad_()
......@@ -101,9 +103,11 @@ class TestBasicCastsBFloat16(_TestBasicCasts):
def tearDown(self):
self.handle._deactivate()
@skipIfRocm
def test_linear_is_bfloat16(self):
self._test_linear(ALWAYS_BFLOAT16)
@skipIfRocm
def test_conv2d_is_bfloat16(self):
self._test_conv2d(ALWAYS_BFLOAT16)
......@@ -227,9 +231,11 @@ class TestTensorCastsBFloat16(_TestTensorCasts):
def tearDown(self):
self.handle._deactivate()
@skipIfRocm
def test_matmul_method_is_bfloat16(self):
self._test_matmul_method(ALWAYS_BFLOAT16)
@skipIfRocm
def test_matmul_op_is_bfloat16(self):
self._test_matmul_op(ALWAYS_BFLOAT16)
......
......@@ -13,6 +13,7 @@ from torch.nn import Parameter
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
from apex.testing.common_utils import skipIfRocm
try:
import amp_C
......@@ -53,6 +54,7 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
pass
@unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_2models2losses1optimizer(self):
model0 = MyModel(1)
model1 = MyModel(2)
......@@ -185,6 +187,7 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_3models2losses1optimizer(self):
model0 = MyModel(1)
......@@ -346,6 +349,7 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_2models2losses2optimizers(self):
model0 = MyModel(1)
model1 = MyModel(2)
......@@ -541,6 +545,7 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_3models2losses2optimizers(self):
model0 = MyModel(1)
model1 = MyModel(2)
......
......@@ -41,7 +41,8 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
def tearDown(self):
pass
@skipIfRocm
def test_2models2losses1optimizer(self):
model0 = MyModel(1)
model1 = MyModel(2)
......@@ -167,6 +168,7 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
if opt_level == "O1":
_amp_state.handle._deactivate()
@skipIfRocm
def test_3models2losses1optimizer(self):
model0 = MyModel(1)
......@@ -323,6 +325,7 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
if opt_level == "O1":
_amp_state.handle._deactivate()
@skipIfRocm
def test_2models2losses2optimizers(self):
model0 = MyModel(1)
model1 = MyModel(2)
......@@ -513,6 +516,7 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
if opt_level == "O1":
_amp_state.handle._deactivate()
@skipIfRocm
def test_3models2losses2optimizers(self):
model0 = MyModel(1)
model1 = MyModel(2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment