Commit 94280d85 authored by wooway777's avatar wooway777
Browse files

issue/434 - added bf16 support for Cambricon MLU

parent f9d16628
...@@ -25,6 +25,8 @@ inline cnclDataType_t getCnclDtype(infiniDtype_t datatype) { ...@@ -25,6 +25,8 @@ inline cnclDataType_t getCnclDtype(infiniDtype_t datatype) {
return cnclFloat32; return cnclFloat32;
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
return cnclFloat16; return cnclFloat16;
case INFINI_DTYPE_BF16:
return cnclBfloat16;
default: default:
std::cerr << "Unsupported data type: " << datatype << std::endl; std::cerr << "Unsupported data type: " << datatype << std::endl;
std::abort(); std::abort();
...@@ -89,9 +91,7 @@ infiniStatus_t allReduce( ...@@ -89,9 +91,7 @@ infiniStatus_t allReduce(
infinicclComm_t comm, infinicclComm_t comm,
infinirtStream_t stream) { infinirtStream_t stream) {
if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) { CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
return INFINI_STATUS_BAD_PARAM;
}
CHECK_CNCL(cnclAllReduce(sendbuf, recvbuf, count, getCnclDtype(datatype), CHECK_CNCL(cnclAllReduce(sendbuf, recvbuf, count, getCnclDtype(datatype),
getCnclRedOp(op), getCnclComm(comm), getCnclRedOp(op), getCnclComm(comm),
...@@ -99,4 +99,5 @@ infiniStatus_t allReduce( ...@@ -99,4 +99,5 @@ infiniStatus_t allReduce(
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
} // namespace infiniccl::cambricon } // namespace infiniccl::cambricon
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment