Commit 98a64039 authored by lcskrishna's avatar lcskrishna
Browse files

bug fixes in sgd kernel in bfp16 bringup

parent b2da92fc
......@@ -271,7 +271,7 @@ void multi_tensor_sgd_cuda(
scale);
}
// Case 5. bfp16, bfp16, bfp16, No
if(grad_type == at::ScalarType::BFloat16 &&
else if(grad_type == at::ScalarType::BFloat16 &&
weight_type == at::ScalarType::BFloat16 &&
num_tensors == 3)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment