cupy_nccl.h 5.01 KB
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#ifndef INCLUDE_GUARD_CUPY_NCCL_H
#define INCLUDE_GUARD_CUPY_NCCL_H

#if !defined(CUPY_NO_CUDA) && !defined(CUPY_USE_HIP)

#include "cuda/cupy_nccl.h"

#elif defined(CUPY_USE_HIP)

#include "hip/cupy_rccl.h"

#else // defined(CUPY_NO_CUDA) && !defined(CUPY_USE_HIP)

#include "stub/cupy_nccl.h"

#endif

#ifndef NCCL_MAJOR
#define NCCL_MAJOR 1
#define NCCL_MINOR 0
#define NCCL_PATCH 0
#endif

#ifndef NCCL_VERSION_CODE
#define NCCL_VERSION_CODE (NCCL_MAJOR * 1000 + NCCL_MINOR * 100 + NCCL_PATCH)
#endif


#if (NCCL_VERSION_CODE >= 2000)

// Variant compiled when NCCL_VERSION_CODE >= 2000: no translation is applied,
// the datatype code supplied by the caller is returned unchanged.
ncclDataType_t _get_proper_datatype(ncclDataType_t datatype) {
    const ncclDataType_t unchanged = datatype;
    return unchanged;
}

#else // #if (NCCL_VERSION_CODE >= 2000)

// Aliases for the datatype enumerators used on the legacy code path.
// This whole branch is compiled only when NCCL_VERSION_CODE < 2000.
#define NCCL_CHAR_V1 ncclChar
#define NCCL_INT_V1 ncclInt
#define NCCL_HALF_V1 ncclHalf
#define NCCL_FLOAT_V1 ncclFloat
#define NCCL_DOUBLE_v1 ncclDouble
#define NCCL_INT64_v1 ncclInt64
#define NCCL_UINT64_v1 ncclUint64
// Marker stored for datatype codes that have no counterpart on this path.
#define NCCL_INVALID_TYPE_V1 nccl_NUM_TYPES

// Translation table indexed by the incoming datatype code. The trailing
// comment on each row names the datatype(s) that the row's index stands for;
// rows holding NCCL_INVALID_TYPE_V1 mark datatypes that cannot be translated.
// The row order is significant: do not reorder.
static const ncclDataType_t TYPE2TYPE_V1[] = {
    NCCL_CHAR_V1,         // ncclInt8, ncclChar
    NCCL_INVALID_TYPE_V1, // ncclUint8
    NCCL_INT_V1,          // ncclInt32, ncclInt
    NCCL_INVALID_TYPE_V1, // ncclUint32
    NCCL_INT64_v1,        // ncclInt64
    NCCL_UINT64_v1,       // ncclUint64
    NCCL_HALF_V1,         // ncclFloat16, ncclHalf
    NCCL_FLOAT_V1,        // ncclFloat32, ncclFloat
    NCCL_DOUBLE_v1        // ncclFloat64, ncclDouble
};

// Translates a caller-supplied datatype code through TYPE2TYPE_V1.
//
// datatype: code used as an index into the table above.
// Returns the table entry for that code. Codes that fall outside the table
// (negative, or >= the number of rows) yield NCCL_INVALID_TYPE_V1 instead of
// reading past the end of the array.
ncclDataType_t _get_proper_datatype(ncclDataType_t datatype) {
    const int table_len = (int)(sizeof(TYPE2TYPE_V1) / sizeof(TYPE2TYPE_V1[0]));
    const int code = (int)datatype;
    if (code < 0 || code >= table_len) {
        // Same marker the table itself uses for untranslatable datatypes.
        return NCCL_INVALID_TYPE_V1;
    }
    return TYPE2TYPE_V1[code];
}

#ifndef CUPY_NO_CUDA
// Placeholder compiled when NCCL_VERSION_CODE < 2000 and CUPY_NO_CUDA is not
// defined: does nothing and reports success.
ncclResult_t ncclGroupStart() {
    const ncclResult_t status = ncclSuccess;
    return status;
}

// Placeholder compiled when NCCL_VERSION_CODE < 2000 and CUPY_NO_CUDA is not
// defined: does nothing and reports success.
ncclResult_t ncclGroupEnd() {
    const ncclResult_t status = ncclSuccess;
    return status;
}
#endif // #ifndef CUPY_NO_CUDA
#endif // #if (NCCL_VERSION_CODE >= 2000) ... #else

#if (NCCL_VERSION_CODE < 2200)
// New function in 2.2
// Placeholder definition used when NCCL_VERSION_CODE < 2200.
// It performs no communication: every argument is ignored and ncclSuccess is
// returned unconditionally.
ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count,
                           ncclDataType_t datatype, int root, ncclComm_t comm,
                           cudaStream_t stream) {
    // Explicitly discard the arguments; nothing is read or written.
    (void)sendbuff;
    (void)recvbuff;
    (void)count;
    (void)datatype;
    (void)root;
    (void)comm;
    (void)stream;
    return ncclSuccess;
}
#endif // #if (NCCL_VERSION_CODE < 2200)

#if (NCCL_VERSION_CODE < 2304)

// Placeholder definition used when NCCL_VERSION_CODE < 2304.
//
// version: output location; receives 0.
// Always returns ncclSuccess.
ncclResult_t ncclGetVersion(int *version) {
    version[0] = 0;  // this placeholder has no real version to report
    return ncclSuccess;
}

#endif // #if (NCCL_VERSION_CODE < 2304)

// Thin wrapper over ncclAllReduce: runs the datatype through
// _get_proper_datatype and forwards every other argument unchanged.
// Returns whatever ncclAllReduce returns.
ncclResult_t _ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
                            ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
                            cudaStream_t stream) {
    return ncclAllReduce(sendbuff, recvbuff, count,
                         _get_proper_datatype(datatype), op, comm, stream);
}


// Thin wrapper over ncclReduce: runs the datatype through
// _get_proper_datatype and forwards every other argument unchanged.
// Returns whatever ncclReduce returns.
ncclResult_t _ncclReduce(const void* sendbuff, void* recvbuff, size_t count,
                         ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm,
                         cudaStream_t stream) {
    return ncclReduce(sendbuff, recvbuff, count,
                      _get_proper_datatype(datatype), op, root, comm, stream);
}


// Thin wrapper over ncclBroadcast: runs the datatype through
// _get_proper_datatype and forwards every other argument unchanged.
// Returns whatever ncclBroadcast returns.
ncclResult_t _ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count,
                            ncclDataType_t datatype, int root, ncclComm_t comm,
                            cudaStream_t stream) {
    return ncclBroadcast(sendbuff, recvbuff, count,
                         _get_proper_datatype(datatype), root, comm, stream);
}


// Thin wrapper over ncclBcast: runs the datatype through
// _get_proper_datatype and forwards every other argument unchanged.
// Returns whatever ncclBcast returns.
ncclResult_t _ncclBcast(void* buff, size_t count, ncclDataType_t datatype, int root,
                        ncclComm_t comm, cudaStream_t stream) {
    return ncclBcast(buff, count, _get_proper_datatype(datatype), root, comm, stream);
}


// Thin wrapper over ncclReduceScatter: runs the datatype through
// _get_proper_datatype and forwards every other argument unchanged.
// Returns whatever ncclReduceScatter returns.
ncclResult_t _ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount,
                                ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
                                cudaStream_t stream) {
    return ncclReduceScatter(sendbuff, recvbuff, recvcount,
                             _get_proper_datatype(datatype), op, comm, stream);
}


// Wrapper over ncclAllGather that runs the datatype through
// _get_proper_datatype and picks the argument order matching the NCCL version
// this file is compiled against. Returns whatever ncclAllGather returns.
ncclResult_t _ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount,
                            ncclDataType_t datatype, ncclComm_t comm,
                            cudaStream_t stream) {
    const ncclDataType_t converted = _get_proper_datatype(datatype);
#if (NCCL_VERSION_CODE < 2000)
    // Legacy call shape: (sendbuff, count, datatype, recvbuff, comm, stream).
    return ncclAllGather(sendbuff, sendcount, converted, recvbuff, comm, stream);
#else
    // Current call shape: (sendbuff, recvbuff, count, datatype, comm, stream).
    return ncclAllGather(sendbuff, recvbuff, sendcount, converted, comm, stream);
#endif // #if (NCCL_VERSION_CODE < 2000)
}

#if (NCCL_VERSION_CODE < 2400)
// New functions in 2.4
// Marks a parameter as intentionally unused. The argument is parenthesised so
// that an expression argument is cast as a whole.
#define UNUSED(x) ((void)(x))

// Placeholder definition used when NCCL_VERSION_CODE < 2400.
//
// comm:       ignored.
// asyncError: output location; when non-null it receives ncclSuccess so the
//             caller never reads an indeterminate value.
// Always returns ncclSuccess.
ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError) {
  UNUSED(comm);
  if (asyncError) {
    // Keep the out-parameter consistent with the returned status.
    *asyncError = ncclSuccess;
  }
  return ncclSuccess;
}

// Placeholder definition used when NCCL_VERSION_CODE < 2400: ignores the
// communicator and does nothing.
// NOTE(review): the real NCCL function is believed to return ncclResult_t
// rather than void — confirm against the NCCL headers before relying on this
// signature.
void ncclCommAbort(ncclComm_t comm) {
  (void)comm;
}
#endif

#if (NCCL_VERSION_CODE < 2700)
// New functions in 2.7
// Placeholder definition used when NCCL_VERSION_CODE < 2700.
// It performs no communication: every argument is ignored and ncclSuccess is
// returned unconditionally.
ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatype,
                      int peer, ncclComm_t comm, cudaStream_t stream) {
    // Explicitly discard the arguments; nothing is sent.
    (void)sendbuff;
    (void)count;
    (void)datatype;
    (void)peer;
    (void)comm;
    (void)stream;
    return ncclSuccess;
}

// Placeholder definition used when NCCL_VERSION_CODE < 2700.
// It performs no communication: every argument is ignored, recvbuff is left
// untouched, and ncclSuccess is returned unconditionally.
ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype,
                      int peer, ncclComm_t comm, cudaStream_t stream) {
    // Explicitly discard the arguments; nothing is received.
    (void)recvbuff;
    (void)count;
    (void)datatype;
    (void)peer;
    (void)comm;
    (void)stream;
    return ncclSuccess;
}
#endif

#endif // #ifndef INCLUDE_GUARD_CUPY_NCCL_H