infiniccl_cambricon.cc 2.59 KB
Newer Older
wooway777's avatar
wooway777 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include "infiniccl_cambricon.h"

#include "../../utils.h"
#include <cncl.h>
#include <cnrt.h>
#include <iostream>
#include <vector>

#define CHECK_CNCL(API__) CHECK_INTERNAL(API__, CNCL_RET_SUCCESS)

inline cnrtQueue_t getCambriconStream(infinirtStream_t stream) {
    if (stream == nullptr) {
        return (cnrtQueue_t)(0);
    }
    return static_cast<cnrtQueue_t>(stream);
}

inline cnclComm_t getCnclComm(infinicclComm_t comm) {
    return static_cast<cnclComm_t>(comm->comm);
}

inline cnclDataType_t getCnclDtype(infiniDtype_t datatype) {
    switch (datatype) {
    case INFINI_DTYPE_F32:
        return cnclFloat32;
    case INFINI_DTYPE_F16:
        return cnclFloat16;
    default:
        std::cerr << "Unsupported data type: " << datatype << std::endl;
        std::abort();
        return cnclFloat16;
    }
}

inline cnclReduceOp_t getCnclRedOp(infinicclReduceOp_t op) {
    switch (op) {
    case INFINICCL_SUM:
        return cnclSum;
    case INFINICCL_PROD:
        return cnclProd;
    case INFINICCL_MAX:
        return cnclMax;
    case INFINICCL_MIN:
        return cnclMin;
    default:
        std::abort();
        return cnclSum;
    }
}

namespace infiniccl::cambricon {

infiniStatus_t commInitAll(
    infinicclComm_t *comms,
    int ndevice,
    const int *device_ids) {

    std::vector<cnclComm_t> cncl_comms(ndevice);
    std::vector<int> rank_list(ndevice);

    for (int i = 0; i < ndevice; i++) {
        rank_list[i] = i;
        CHECK_INTERNAL(cnrtSetDevice(device_ids[i]), CNRT_RET_SUCCESS);
    }

    CHECK_CNCL(cnclInitComms(cncl_comms.data(), ndevice,
                             (int const *)device_ids, rank_list.data(),
                             ndevice, nullptr));

    for (int i = 0; i < ndevice; i++) {
        comms[i] = new InfinicclComm{INFINI_DEVICE_CAMBRICON, device_ids[i], (void *)(cncl_comms[i])};
    }

    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t commDestroy(infinicclComm_t comm) {
    CHECK_CNCL(cnclFreeComm(getCnclComm(comm)));
    delete comm;
    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t allReduce(
    void *sendbuf,
    void *recvbuf,
    size_t count,
    infiniDtype_t datatype,
    infinicclReduceOp_t op,
    infinicclComm_t comm,
    infinirtStream_t stream) {

    if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) {
        return INFINI_STATUS_BAD_PARAM;
    }

    CHECK_CNCL(cnclAllReduce(sendbuf, recvbuf, count, getCnclDtype(datatype),
                             getCnclRedOp(op), getCnclComm(comm),
                             getCambriconStream(stream)));

    return INFINI_STATUS_SUCCESS;
}
} // namespace infiniccl::cambricon