Commit 81093e0b authored by PanZezhong1725's avatar PanZezhong1725
Browse files

issue/434 nccl support bf16

parent 9ad23fad
...@@ -22,6 +22,8 @@ inline ncclDataType_t getNcclDtype(infiniDtype_t datatype) { ...@@ -22,6 +22,8 @@ inline ncclDataType_t getNcclDtype(infiniDtype_t datatype) {
return ncclFloat; return ncclFloat;
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
return ncclHalf; return ncclHalf;
case INFINI_DTYPE_BF16:
return ncclBfloat16;
default: default:
std::abort(); std::abort();
return ncclHalf; return ncclHalf;
...@@ -82,9 +84,7 @@ infiniStatus_t allReduce( ...@@ -82,9 +84,7 @@ infiniStatus_t allReduce(
infinicclComm_t comm, infinicclComm_t comm,
infinirtStream_t stream) { infinirtStream_t stream) {
if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) { CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
return INFINI_STATUS_BAD_PARAM;
}
CHECK_NCCL(ncclAllReduce(sendbuf, recvbuf, count, getNcclDtype(datatype), CHECK_NCCL(ncclAllReduce(sendbuf, recvbuf, count, getNcclDtype(datatype),
getNcclRedOp(op), getNcclComm(comm), getCudaStream(stream))); getNcclRedOp(op), getNcclComm(comm), getCudaStream(stream)));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment