Commit 93d261ed authored by zhushuang
Browse files

issue/1021 - feat: support bf16 in infiniccl on moore_gpu_arch mp_31 with mccl

parent 718b18cf
...@@ -23,6 +23,12 @@ inline mcclDataType_t getMcclDtype(infiniDtype_t datatype) { ...@@ -23,6 +23,12 @@ inline mcclDataType_t getMcclDtype(infiniDtype_t datatype) {
return mcclFloat; return mcclFloat;
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
return mcclHalf; return mcclHalf;
#if MARCH_TYPE == 310
case INFINI_DTYPE_BF16:
return mcclBfloat16;
#endif
default: default:
std::abort(); std::abort();
return mcclHalf; return mcclHalf;
...@@ -83,9 +89,16 @@ infiniStatus_t allReduce( ...@@ -83,9 +89,16 @@ infiniStatus_t allReduce(
infinicclComm_t comm, infinicclComm_t comm,
infinirtStream_t stream) { infinirtStream_t stream) {
if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) { #if MARCH_TYPE == 310
return INFINI_STATUS_BAD_PARAM; CHECK_DTYPE(datatype,
} INFINI_DTYPE_F32,
INFINI_DTYPE_F16,
INFINI_DTYPE_BF16);
#else
CHECK_DTYPE(datatype,
INFINI_DTYPE_F32,
INFINI_DTYPE_F16);
#endif
CHECK_MCCL(mcclAllReduce(sendbuf, recvbuf, count, getMcclDtype(datatype), CHECK_MCCL(mcclAllReduce(sendbuf, recvbuf, count, getMcclDtype(datatype),
getMcclRedOp(op), getMcclComm(comm), getMusaStream(stream))); getMcclRedOp(op), getMcclComm(comm), getMusaStream(stream)));
......
...@@ -180,6 +180,12 @@ option("moore-gpu") ...@@ -180,6 +180,12 @@ option("moore-gpu")
set_description("Whether to compile implementations for Moore Threads GPU") set_description("Whether to compile implementations for Moore Threads GPU")
option_end() option_end()
-- Build option naming the Moore Threads GPU architecture to compile for.
-- Build scripts query it with get_config("moore-gpu-arch"); "mp_31" is the
-- configured default value.
option("moore-gpu-arch")
    set_description("Set Moore GPU architecture (e.g. mp_31)")
    set_showmenu(true)
    set_default("mp_31")
option_end()
if has_config("moore-gpu") then if has_config("moore-gpu") then
add_defines("ENABLE_MOORE_API") add_defines("ENABLE_MOORE_API")
includes("xmake/moore.lua") includes("xmake/moore.lua")
......
...@@ -16,7 +16,22 @@ rule("mu") ...@@ -16,7 +16,22 @@ rule("mu")
local mcc = MUSA_ROOT .. "/bin/mcc" local mcc = MUSA_ROOT .. "/bin/mcc"
local includedirs = table.concat(target:get("includedirs"), " ") local includedirs = table.concat(target:get("includedirs"), " ")
local args = {"-c", sourcefile, "-o", objectfile, "-I" .. MUSA_ROOT .. "/include", "-O3", "-fPIC", "-Wall", "-std=c++17", "-pthread"} local args = {
"-c", sourcefile,
"-o", objectfile,
"-I" .. MUSA_ROOT .. "/include",
"-O3",
"-fPIC",
"-Wall",
"-std=c++17",
"-pthread"
}
local moore_gpu_arch = get_config("moore-gpu-arch")
if moore_gpu_arch == "mp_31" then
table.insert(args, 1, "--cuda-gpu-arch=mp_31")
end
for _, includedir in ipairs(target:get("includedirs")) do for _, includedir in ipairs(target:get("includedirs")) do
table.insert(args, "-I" .. includedir) table.insert(args, "-I" .. includedir)
end end
...@@ -76,6 +91,12 @@ target("infiniccl-moore") ...@@ -76,6 +91,12 @@ target("infiniccl-moore")
if has_config("ccl") then if has_config("ccl") then
add_links("libmccl.so") add_links("libmccl.so")
add_files("../src/infiniccl/moore/*.cc") add_files("../src/infiniccl/moore/*.cc")
-- The mp_31 Moore GPU architecture supports mcclBfloat16 in MCCL
if get_config("moore-gpu-arch") == "mp_31" then
add_defines("MARCH_TYPE=310")
add_cxxflags("-Wno-unused-function")
end
end end
set_languages("cxx17") set_languages("cxx17")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment