Unverified Commit 5cfdc014 authored by Peng's avatar Peng Committed by GitHub
Browse files

Merge pull request #11 from lcskrishna/cl/fused-optimizers-bfp16

[FusedOptimizers] Bug fixes in fused optimizers for fp16/bfp16.
parents bdd481d1 9297be60
......@@ -271,7 +271,7 @@ void multi_tensor_sgd_cuda(
scale);
}
// Case 5. bfp16, bfp16, bfp16, No
if(grad_type == at::ScalarType::BFloat16 &&
else if(grad_type == at::ScalarType::BFloat16 &&
weight_type == at::ScalarType::BFloat16 &&
num_tensors == 3)
{
......
......@@ -54,7 +54,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
pass
@unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_2models2losses1optimizer(self):
model0 = MyModel(1)
model1 = MyModel(2)
......@@ -187,7 +186,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_3models2losses1optimizer(self):
model0 = MyModel(1)
......@@ -349,7 +347,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_2models2losses2optimizers(self):
model0 = MyModel(1)
model1 = MyModel(2)
......@@ -545,7 +542,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_3models2losses2optimizers(self):
model0 = MyModel(1)
model1 = MyModel(2)
......
......@@ -44,7 +44,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
def tearDown(self):
pass
@skipIfRocm
def test_2models2losses1optimizer(self):
model0 = MyModel(1)
model1 = MyModel(2)
......@@ -170,7 +169,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
if opt_level == "O1":
_amp_state.handle._deactivate()
@skipIfRocm
def test_3models2losses1optimizer(self):
model0 = MyModel(1)
......@@ -327,7 +325,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
if opt_level == "O1":
_amp_state.handle._deactivate()
@skipIfRocm
def test_2models2losses2optimizers(self):
model0 = MyModel(1)
model1 = MyModel(2)
......@@ -518,7 +515,6 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
if opt_level == "O1":
_amp_state.handle._deactivate()
@skipIfRocm
def test_3models2losses2optimizers(self):
model0 = MyModel(1)
model1 = MyModel(2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment