Commit 9297be60 authored by lcskrishna's avatar lcskrishna
Browse files

enable skipped unit tests fused_sgd, multiple_models_and_optimizers

parent ecff271c
...@@ -54,7 +54,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -54,7 +54,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
pass pass
@unittest.skipIf(disabled, "amp_C is unavailable") @unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_2models2losses1optimizer(self): def test_2models2losses1optimizer(self):
model0 = MyModel(1) model0 = MyModel(1)
model1 = MyModel(2) model1 = MyModel(2)
...@@ -187,7 +186,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -187,7 +186,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
_amp_state.handle._deactivate() _amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable") @unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_3models2losses1optimizer(self): def test_3models2losses1optimizer(self):
model0 = MyModel(1) model0 = MyModel(1)
...@@ -349,7 +347,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -349,7 +347,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
_amp_state.handle._deactivate() _amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable") @unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_2models2losses2optimizers(self): def test_2models2losses2optimizers(self):
model0 = MyModel(1) model0 = MyModel(1)
model1 = MyModel(2) model1 = MyModel(2)
...@@ -545,7 +542,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -545,7 +542,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
_amp_state.handle._deactivate() _amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable") @unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_3models2losses2optimizers(self): def test_3models2losses2optimizers(self):
model0 = MyModel(1) model0 = MyModel(1)
model1 = MyModel(2) model1 = MyModel(2)
......
...@@ -44,7 +44,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -44,7 +44,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
def tearDown(self): def tearDown(self):
pass pass
@skipIfRocm
def test_2models2losses1optimizer(self): def test_2models2losses1optimizer(self):
model0 = MyModel(1) model0 = MyModel(1)
model1 = MyModel(2) model1 = MyModel(2)
...@@ -170,7 +169,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -170,7 +169,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
if opt_level == "O1": if opt_level == "O1":
_amp_state.handle._deactivate() _amp_state.handle._deactivate()
@skipIfRocm
def test_3models2losses1optimizer(self): def test_3models2losses1optimizer(self):
model0 = MyModel(1) model0 = MyModel(1)
...@@ -327,7 +325,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -327,7 +325,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
if opt_level == "O1": if opt_level == "O1":
_amp_state.handle._deactivate() _amp_state.handle._deactivate()
@skipIfRocm
def test_2models2losses2optimizers(self): def test_2models2losses2optimizers(self):
model0 = MyModel(1) model0 = MyModel(1)
model1 = MyModel(2) model1 = MyModel(2)
...@@ -518,7 +515,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -518,7 +515,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
if opt_level == "O1": if opt_level == "O1":
_amp_state.handle._deactivate() _amp_state.handle._deactivate()
@skipIfRocm
def test_3models2losses2optimizers(self): def test_3models2losses2optimizers(self):
model0 = MyModel(1) model0 = MyModel(1)
model1 = MyModel(2) model1 = MyModel(2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment