Unverified commit 5cfdc014, authored by Peng, committed by GitHub
Browse files

Merge pull request #11 from lcskrishna/cl/fused-optimizers-bfp16

[FusedOptimizers] Bug fixes in fused optimizers for fp16/bfp16.
parents bdd481d1 9297be60
...@@ -271,7 +271,7 @@ void multi_tensor_sgd_cuda( ...@@ -271,7 +271,7 @@ void multi_tensor_sgd_cuda(
scale); scale);
} }
// Case 5. bfp16, bfp16, bfp16, No // Case 5. bfp16, bfp16, bfp16, No
if(grad_type == at::ScalarType::BFloat16 && else if(grad_type == at::ScalarType::BFloat16 &&
weight_type == at::ScalarType::BFloat16 && weight_type == at::ScalarType::BFloat16 &&
num_tensors == 3) num_tensors == 3)
{ {
......
...@@ -54,7 +54,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -54,7 +54,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
pass pass
@unittest.skipIf(disabled, "amp_C is unavailable") @unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_2models2losses1optimizer(self): def test_2models2losses1optimizer(self):
model0 = MyModel(1) model0 = MyModel(1)
model1 = MyModel(2) model1 = MyModel(2)
...@@ -187,7 +186,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -187,7 +186,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
_amp_state.handle._deactivate() _amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable") @unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_3models2losses1optimizer(self): def test_3models2losses1optimizer(self):
model0 = MyModel(1) model0 = MyModel(1)
...@@ -349,7 +347,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -349,7 +347,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
_amp_state.handle._deactivate() _amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable") @unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_2models2losses2optimizers(self): def test_2models2losses2optimizers(self):
model0 = MyModel(1) model0 = MyModel(1)
model1 = MyModel(2) model1 = MyModel(2)
...@@ -545,7 +542,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -545,7 +542,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
_amp_state.handle._deactivate() _amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable") @unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_3models2losses2optimizers(self): def test_3models2losses2optimizers(self):
model0 = MyModel(1) model0 = MyModel(1)
model1 = MyModel(2) model1 = MyModel(2)
......
...@@ -44,7 +44,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -44,7 +44,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
def tearDown(self): def tearDown(self):
pass pass
@skipIfRocm
def test_2models2losses1optimizer(self): def test_2models2losses1optimizer(self):
model0 = MyModel(1) model0 = MyModel(1)
model1 = MyModel(2) model1 = MyModel(2)
...@@ -170,7 +169,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -170,7 +169,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
if opt_level == "O1": if opt_level == "O1":
_amp_state.handle._deactivate() _amp_state.handle._deactivate()
@skipIfRocm
def test_3models2losses1optimizer(self): def test_3models2losses1optimizer(self):
model0 = MyModel(1) model0 = MyModel(1)
...@@ -327,7 +325,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -327,7 +325,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
if opt_level == "O1": if opt_level == "O1":
_amp_state.handle._deactivate() _amp_state.handle._deactivate()
@skipIfRocm
def test_2models2losses2optimizers(self): def test_2models2losses2optimizers(self):
model0 = MyModel(1) model0 = MyModel(1)
model1 = MyModel(2) model1 = MyModel(2)
...@@ -518,7 +515,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase): ...@@ -518,7 +515,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
if opt_level == "O1": if opt_level == "O1":
_amp_state.handle._deactivate() _amp_state.handle._deactivate()
@skipIfRocm
def test_3models2losses2optimizers(self): def test_3models2losses2optimizers(self):
model0 = MyModel(1) model0 = MyModel(1)
model1 = MyModel(2) model1 = MyModel(2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment