infiniccl_ascend.cc 2.55 KB
Newer Older
Pan Zezhong's avatar
Pan Zezhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include "infiniccl_ascend.h"

#include "../../utils.h"

#include <acl/acl.h>
#include <hccl.h>

#include <iostream>
#include <vector>

#define CHECK_HCCL(API__) CHECK_INTERNAL(API__, HCCL_SUCCESS)

inline aclrtStream getAscendStream(infinirtStream_t stream) {
    if (stream == nullptr) {
        return 0;
    }
    return static_cast<aclrtStream>(stream);
}

inline HcclComm getHcclComm(infinicclComm_t comm) {
    return static_cast<HcclComm>(comm->comm);
}

inline HcclDataType getAscendDtype(infiniDtype_t datatype) {
    switch (datatype) {
    case INFINI_DTYPE_F32:
        return HCCL_DATA_TYPE_FP32;
    case INFINI_DTYPE_F16:
        return HCCL_DATA_TYPE_FP16;
zhangyunze's avatar
zhangyunze committed
30
31
    case INFINI_DTYPE_BF16:
        return HCCL_DATA_TYPE_BFP16;
Pan Zezhong's avatar
Pan Zezhong committed
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
    default:
        std::cerr << "Unsupported data type: " << datatype << std::endl;
        std::abort();
        return HCCL_DATA_TYPE_FP16;
    }
}

inline HcclReduceOp getHcclRedOp(infinicclReduceOp_t op) {
    switch (op) {
    case INFINICCL_SUM:
        return HCCL_REDUCE_SUM;
    case INFINICCL_PROD:
        return HCCL_REDUCE_PROD;
    case INFINICCL_MAX:
        return HCCL_REDUCE_MAX;
    case INFINICCL_MIN:
        return HCCL_REDUCE_MIN;
    default:
        std::abort();
        return HCCL_REDUCE_SUM;
    }
}

namespace infiniccl::ascend {

infiniStatus_t commInitAll(
    infinicclComm_t *comms,
    int ndevice,
    const int *device_ids) {
    // Ascend requires all devices to be initialized before calling HcclCommInitAll.
    for (int i = ndevice - 1; i >= 0; i--) {
        aclrtSetDevice(device_ids[i]);
    }

    std::vector<HcclComm> hccl_comms(ndevice);
    CHECK_HCCL(HcclCommInitAll(ndevice, (int32_t *)device_ids, hccl_comms.data()));

    for (int i = 0; i < ndevice; i++) {
        comms[i] = new InfinicclComm{INFINI_DEVICE_ASCEND, device_ids[i], (void *)(hccl_comms[i])};
    }

    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t commDestroy(infinicclComm_t comm) {
    CHECK_HCCL(HcclCommDestroy(getHcclComm(comm)));
    delete comm;
    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t allReduce(
    void *sendbuf,
    void *recvbuf,
    size_t count,
    infiniDtype_t datatype,
    infinicclReduceOp_t op,
    infinicclComm_t comm,
    infinirtStream_t stream) {

zhangyunze's avatar
zhangyunze committed
91
    CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
Pan Zezhong's avatar
Pan Zezhong committed
92
93
94
95
96
97
98
99

    CHECK_HCCL(HcclAllReduce(sendbuf, recvbuf, (uint64_t)count,
                             getAscendDtype(datatype), getHcclRedOp(op),
                             getHcclComm(comm), getAscendStream(stream)));

    return INFINI_STATUS_SUCCESS;
}
} // namespace infiniccl::ascend