Commit 4ab20362 authored by Pan Zezhong's avatar Pan Zezhong
Browse files

issue/247 接入昇腾通信库

parent 4799ddbf
...@@ -92,29 +92,32 @@ void *testAllReduceThread(void *arg) { ...@@ -92,29 +92,32 @@ void *testAllReduceThread(void *arg) {
ThreadArgs *args = (ThreadArgs *)arg; ThreadArgs *args = (ThreadArgs *)arg;
*(args->result) = 1; *(args->result) = 1;
TEST_INFINI_THREAD(infinirtSetDevice(args->device_type, args->device_id)); TEST_INFINI_THREAD(infinirtSetDevice(args->device_type, args->device_id));
infinirtStream_t stream;
TEST_INFINI_THREAD(infinirtStreamCreate(&stream));
void *output = std::malloc(args->count * infiniSizeOf(args->dtype)); void *output = std::malloc(args->count * infiniSizeOf(args->dtype));
std::memset(output, 0, args->count * infiniSizeOf(args->dtype)); std::memset(output, 0, args->count * infiniSizeOf(args->dtype));
void *buf; void *buf;
TEST_INFINI_THREAD(infinirtMalloc(&buf, args->count * infiniSizeOf(args->dtype))); TEST_INFINI_THREAD(infinirtMalloc(&buf, args->count * infiniSizeOf(args->dtype)));
TEST_INFINI_THREAD(infinirtMemcpy(buf, args->data, args->count * infiniSizeOf(args->dtype), INFINIRT_MEMCPY_H2D)); TEST_INFINI_THREAD(infinirtMemcpy(buf, args->data, args->count * infiniSizeOf(args->dtype), INFINIRT_MEMCPY_H2D));
TEST_INFINI_THREAD(infinicclAllReduce(buf, buf, args->count, args->dtype, INFINICCL_SUM, args->comm, NULL)); TEST_INFINI_THREAD(infinicclAllReduce(buf, buf, args->count, args->dtype, INFINICCL_SUM, args->comm, stream));
TEST_INFINI_THREAD(infinirtDeviceSynchronize()); TEST_INFINI_THREAD(infinirtDeviceSynchronize());
TEST_INFINI_THREAD(infinirtMemcpy(output, buf, args->count * infiniSizeOf(args->dtype), INFINIRT_MEMCPY_D2H)); TEST_INFINI_THREAD(infinirtMemcpy(output, buf, args->count * infiniSizeOf(args->dtype), INFINIRT_MEMCPY_D2H));
if (checkData(output, args->ans, args->dtype, args->count) != 0) { if (checkData(output, args->ans, args->dtype, args->count) != 0) {
std::free(output); std::free(output);
infinirtFree(buf); infinirtFree(buf);
infinirtStreamDestroy(stream);
return nullptr; return nullptr;
} }
for (size_t i = 0; i < WARM_UPS; i++) { for (size_t i = 0; i < WARM_UPS; i++) {
TEST_INFINI_THREAD(infinicclAllReduce(buf, buf, args->count, args->dtype, INFINICCL_SUM, args->comm, NULL)); TEST_INFINI_THREAD(infinicclAllReduce(buf, buf, args->count, args->dtype, INFINICCL_SUM, args->comm, stream));
} }
TEST_INFINI_THREAD(infinirtDeviceSynchronize()); TEST_INFINI_THREAD(infinirtDeviceSynchronize());
// measure time // measure time
auto start = std::chrono::high_resolution_clock::now(); auto start = std::chrono::high_resolution_clock::now();
for (size_t i = 0; i < ITERATIONS; i++) { for (size_t i = 0; i < ITERATIONS; i++) {
TEST_INFINI_THREAD(infinicclAllReduce(buf, buf, args->count, args->dtype, INFINICCL_SUM, args->comm, NULL)); TEST_INFINI_THREAD(infinicclAllReduce(buf, buf, args->count, args->dtype, INFINICCL_SUM, args->comm, stream));
} }
TEST_INFINI_THREAD(infinirtDeviceSynchronize()); TEST_INFINI_THREAD(infinirtDeviceSynchronize());
auto end = std::chrono::high_resolution_clock::now(); auto end = std::chrono::high_resolution_clock::now();
...@@ -125,6 +128,7 @@ void *testAllReduceThread(void *arg) { ...@@ -125,6 +128,7 @@ void *testAllReduceThread(void *arg) {
std::free(output); std::free(output);
infinirtFree(buf); infinirtFree(buf);
infinirtStreamDestroy(stream);
return nullptr; return nullptr;
} }
......
...@@ -60,6 +60,7 @@ ParsedArgs parseArgs(int argc, char *argv[]) { ...@@ -60,6 +60,7 @@ ParsedArgs parseArgs(int argc, char *argv[]) {
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
ParsedArgs args = parseArgs(argc, argv); ParsedArgs args = parseArgs(argc, argv);
int ndevice = 0; int ndevice = 0;
infinirtInit();
if (infinirtGetDeviceCount(args.device_type, &ndevice) != INFINI_STATUS_SUCCESS) { if (infinirtGetDeviceCount(args.device_type, &ndevice) != INFINI_STATUS_SUCCESS) {
std::cout << "Failed to get device count" << std::endl; std::cout << "Failed to get device count" << std::endl;
return -1; return -1;
......
#include "infiniccl_ascend.h"
#include "../../utils.h"
#include <acl/acl.h>
#include <hccl.h>
#include <iostream>
#include <vector>
#define CHECK_HCCL(API__) CHECK_INTERNAL(API__, HCCL_SUCCESS)
// Converts an InfiniRT stream handle into the ACL stream handle handed to HCCL.
// A null InfiniRT handle is forwarded as a null ACL stream handle.
inline aclrtStream getAscendStream(infinirtStream_t stream) {
    return stream != nullptr ? static_cast<aclrtStream>(stream) : aclrtStream{};
}
// Returns the HCCL communicator handle stored in the `comm` field of an
// InfiniCCL communicator object. `comm` is dereferenced without a null check.
inline HcclComm getHcclComm(infinicclComm_t comm) {
    auto raw_handle = comm->comm;
    return static_cast<HcclComm>(raw_handle);
}
// Maps an Infini data type onto the matching HCCL data type.
// Only F32 and F16 are translated; any other value prints a diagnostic to
// stderr and aborts the process.
inline HcclDataType getAscendDtype(infiniDtype_t datatype) {
    if (datatype == INFINI_DTYPE_F32) {
        return HCCL_DATA_TYPE_FP32;
    }
    if (datatype == INFINI_DTYPE_F16) {
        return HCCL_DATA_TYPE_FP16;
    }
    std::cerr << "Unsupported data type: " << datatype << std::endl;
    std::abort(); // never returns, so no fallback value is needed
}
// Maps an InfiniCCL reduction operator onto the matching HCCL reduction operator.
// SUM, PROD, MAX and MIN are translated; any other value is reported on stderr
// and aborts the process (mirroring how getAscendDtype reports unsupported
// data types, instead of aborting without any message).
inline HcclReduceOp getHcclRedOp(infinicclReduceOp_t op) {
    switch (op) {
    case INFINICCL_SUM:
        return HCCL_REDUCE_SUM;
    case INFINICCL_PROD:
        return HCCL_REDUCE_PROD;
    case INFINICCL_MAX:
        return HCCL_REDUCE_MAX;
    case INFINICCL_MIN:
        return HCCL_REDUCE_MIN;
    default:
        std::cerr << "Unsupported reduce operation: " << static_cast<int>(op) << std::endl;
        std::abort(); // never returns, so no fallback value is needed
    }
}
namespace infiniccl::ascend {

// Creates one HCCL communicator per entry of `device_ids` and wraps each in a
// heap-allocated InfinicclComm written to `comms[i]`.
//
// comms      - output array with room for `ndevice` communicator handles.
// ndevice    - number of devices; must be positive.
// device_ids - array of `ndevice` device ids.
//
// Returns INFINI_STATUS_BAD_PARAM for null arrays or a non-positive device
// count, the CHECK_INTERNAL failure status when an ACL/HCCL call fails, and
// INFINI_STATUS_SUCCESS otherwise. Each created communicator is released with
// commDestroy.
infiniStatus_t commInitAll(
    infinicclComm_t *comms,
    int ndevice,
    const int *device_ids) {
    // A negative count would otherwise be converted into a huge vector size.
    if (comms == nullptr || device_ids == nullptr || ndevice <= 0) {
        return INFINI_STATUS_BAD_PARAM;
    }

    // Ascend requires all devices to be initialized before calling HcclCommInitAll.
    // Iterating backwards leaves device_ids[0] as the last device that was set.
    for (int i = ndevice - 1; i >= 0; i--) {
        CHECK_INTERNAL(aclrtSetDevice(device_ids[i]), ACL_SUCCESS);
    }

    // HcclCommInitAll takes a mutable int32_t array, so hand it a private copy
    // rather than casting away the const of the caller's buffer.
    std::vector<int32_t> devices(device_ids, device_ids + ndevice);
    std::vector<HcclComm> hccl_comms(ndevice);
    CHECK_HCCL(HcclCommInitAll(static_cast<uint32_t>(ndevice), devices.data(), hccl_comms.data()));

    for (int i = 0; i < ndevice; i++) {
        comms[i] = new InfinicclComm{INFINI_DEVICE_ASCEND, device_ids[i], static_cast<void *>(hccl_comms[i])};
    }
    return INFINI_STATUS_SUCCESS;
}

// Destroys the HCCL communicator held by `comm` and frees the wrapper object.
// If HcclCommDestroy fails, the failure status is returned and the wrapper is
// left allocated.
infiniStatus_t commDestroy(infinicclComm_t comm) {
    CHECK_HCCL(HcclCommDestroy(getHcclComm(comm)));
    delete comm;
    return INFINI_STATUS_SUCCESS;
}

// Enqueues an HCCL all-reduce of `count` elements from `sendbuf` into `recvbuf`
// on `stream`.
//
// Only F32/F16 data and the SUM/PROD/MAX/MIN operators are accepted; anything
// else yields INFINI_STATUS_BAD_PARAM. Validating the operator here keeps an
// unsupported value from reaching getHcclRedOp, which would abort the process.
infiniStatus_t allReduce(
    void *sendbuf,
    void *recvbuf,
    size_t count,
    infiniDtype_t datatype,
    infinicclReduceOp_t op,
    infinicclComm_t comm,
    infinirtStream_t stream) {
    if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) {
        return INFINI_STATUS_BAD_PARAM;
    }
    switch (op) {
    case INFINICCL_SUM:
    case INFINICCL_PROD:
    case INFINICCL_MAX:
    case INFINICCL_MIN:
        break;
    default:
        return INFINI_STATUS_BAD_PARAM;
    }

    CHECK_HCCL(HcclAllReduce(sendbuf, recvbuf, static_cast<uint64_t>(count),
                             getAscendDtype(datatype), getHcclRedOp(op),
                             getHcclComm(comm), getAscendStream(stream)));
    return INFINI_STATUS_SUCCESS;
}

} // namespace infiniccl::ascend
#ifndef INFINICCL_ASCEND_H_
#define INFINICCL_ASCEND_H_
#include "../infiniccl_impl.h"
// Selects the Ascend backend variant for InfiniCCL at compile time.
// The IMPL variant is expanded only when both ENABLE_ASCEND_API and ENABLE_CCL
// are defined; every other configuration expands the NOOP variant instead.
// NOTE(review): presumably IMPL declares the infiniccl::ascend entry points
// (commInitAll / commDestroy / allReduce) and NOOP supplies fallback stubs —
// confirm against the macro definitions, which are not visible in this file.
#if defined(ENABLE_ASCEND_API) && defined(ENABLE_CCL)
INFINICCL_DEVICE_API_IMPL(ascend)
#else
INFINICCL_DEVICE_API_NOOP(ascend)
#endif
#endif /* INFINICCL_ASCEND_H_ */
#include "infiniccl.h" #include "infiniccl.h"
#include "./ascend/infiniccl_ascend.h"
#include "./cuda/infiniccl_cuda.h" #include "./cuda/infiniccl_cuda.h"
__C infiniStatus_t infinicclCommInitAll( __C infiniStatus_t infinicclCommInitAll(
...@@ -14,6 +15,7 @@ __C infiniStatus_t infinicclCommInitAll( ...@@ -14,6 +15,7 @@ __C infiniStatus_t infinicclCommInitAll(
switch (device_type) { switch (device_type) {
COMM_INIT_ALL(INFINI_DEVICE_NVIDIA, cuda) COMM_INIT_ALL(INFINI_DEVICE_NVIDIA, cuda)
COMM_INIT_ALL(INFINI_DEVICE_ASCEND, ascend)
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -32,6 +34,7 @@ __C infiniStatus_t infinicclCommDestroy(infinicclComm_t comm) { ...@@ -32,6 +34,7 @@ __C infiniStatus_t infinicclCommDestroy(infinicclComm_t comm) {
switch (comm->device_type) { switch (comm->device_type) {
COMM_DESTROY(INFINI_DEVICE_NVIDIA, cuda) COMM_DESTROY(INFINI_DEVICE_NVIDIA, cuda)
COMM_DESTROY(INFINI_DEVICE_ASCEND, ascend)
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -57,6 +60,7 @@ __C infiniStatus_t infinicclAllReduce( ...@@ -57,6 +60,7 @@ __C infiniStatus_t infinicclAllReduce(
switch (comm->device_type) { switch (comm->device_type) {
ALL_REDUCE(INFINI_DEVICE_NVIDIA, cuda) ALL_REDUCE(INFINI_DEVICE_NVIDIA, cuda)
ALL_REDUCE(INFINI_DEVICE_ASCEND, ascend)
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
......
...@@ -242,6 +242,9 @@ target("infiniccl") ...@@ -242,6 +242,9 @@ target("infiniccl")
if has_config("nv-gpu") then if has_config("nv-gpu") then
add_deps("infiniccl-cuda") add_deps("infiniccl-cuda")
end end
if has_config("ascend-npu") then
add_deps("infiniccl-ascend")
end
set_languages("cxx17") set_languages("cxx17")
......
...@@ -63,3 +63,18 @@ target("infinirt-ascend") ...@@ -63,3 +63,18 @@ target("infinirt-ascend")
add_files("$(projectdir)/src/infinirt/ascend/*.cc") add_files("$(projectdir)/src/infinirt/ascend/*.cc")
add_cxflags("-lstdc++ -Wall -Werror -fPIC") add_cxflags("-lstdc++ -Wall -Werror -fPIC")
target_end() target_end()
-- Static library target holding the Ascend (HCCL) backend of InfiniCCL.
target("infiniccl-ascend")
set_kind("static")
-- Depends on the runtime layer and the shared utilities target.
add_deps("infinirt")
add_deps("infini-utils")
-- All warnings are enabled and treated as errors.
set_warnings("all", "error")
set_languages("cxx17")
-- Empty install hook; presumably keeps this internal library out of the install step — confirm.
on_install(function (target) end)
-- HCCL headers, the HCCL link, and the backend sources are added only when the `ccl` option is set;
-- without it the target has no source files.
if has_config("ccl") then
add_includedirs(ASCEND_HOME .. "/include/hccl")
add_links("libhccl.so")
add_files("../src/infiniccl/ascend/*.cc")
add_cxflags("-lstdc++ -fPIC")
end
target_end()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment