/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/sddmm.cu
 * \brief SDDMM C APIs and definitions (CUDA implementations).
 */
#include <dgl/array.h>
#include "./sddmm.cuh"
#include "./functor.cuh"

namespace dgl {
namespace aten {

/*!
 * \brief Dispatch a runtime binary-operator name to a compile-time functor type.
 *
 * In the branch whose name equals \p op, \p Op is a typedef for the matching
 * cuda::binary functor instantiated on \c DType, and the trailing block is
 * expanded with that typedef in scope. \c DType must already be a type at the
 * expansion site (in this file it is bound by an enclosing SWITCH_BITS).
 *
 * Accepted names: "add", "sub", "mul", "div", "copy_lhs", "copy_rhs", "dot".
 * Any other name is reported with LOG(FATAL).
 */
#define SWITCH_OP(op, Op, ...)                                      \
  do {                                                              \
    if ((op) == "add") {                                            \
      typedef cuda::binary::Add<DType> Op;                          \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "sub") {                                     \
      typedef cuda::binary::Sub<DType> Op;                          \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "mul") {                                     \
      typedef cuda::binary::Mul<DType> Op;                          \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "div") {                                     \
      typedef cuda::binary::Div<DType> Op;                          \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "copy_lhs") {                                \
      typedef cuda::binary::CopyLhs<DType> Op;                      \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "copy_rhs") {                                \
      typedef cuda::binary::CopyRhs<DType> Op;                      \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "dot") {                                     \
      typedef cuda::binary::Dot<DType> Op;                          \
      { __VA_ARGS__ }                                               \
    } else {                                                        \
      LOG(FATAL) << "Unsupported SpMM/SDDMM binary operator: " << (op); \
    }                                                               \
  } while (0)

/*!
 * \brief Dispatch a runtime rhs target (0, 1 or 2) to a `constexpr int`.
 *
 * In the matching branch \p RhsTarget is a `constexpr int` equal to
 * \p rhs_target, so the trailing block can use it as a template argument.
 * NOTE(review): the meaning of each target value is defined by the kernels
 * in sddmm.cuh; confirm there.
 *
 * Any other value is a fatal error, consistent with SWITCH_OP. Merely logging
 * it would skip the body entirely and leave the kernel output unwritten.
 */
#define SWITCH_RHS(rhs_target, RhsTarget, ...)                        \
  do {                                                                \
    if ((rhs_target) == 0) {                                          \
      constexpr int RhsTarget = 0;                                    \
      { __VA_ARGS__ }                                                 \
    } else if ((rhs_target) == 1) {                                   \
      constexpr int RhsTarget = 1;                                    \
      { __VA_ARGS__ }                                                 \
    } else if ((rhs_target) == 2) {                                   \
      constexpr int RhsTarget = 2;                                    \
      { __VA_ARGS__ }                                                 \
    } else {                                                          \
      LOG(FATAL) << "Invalid rhs target: " << (rhs_target);           \
    }                                                                 \
  } while (0)

/*!
 * \brief Dispatch a runtime (lhs_target, rhs_target) pair to two `constexpr int`s.
 *
 * The outer level binds \p LhsTarget (0, 1 or 2) and then delegates to
 * SWITCH_RHS, which binds \p RhsTarget and expands the trailing block with
 * both names in scope. Validation of \p rhs_target is done by SWITCH_RHS.
 *
 * An out-of-range lhs target is a fatal error, matching how SWITCH_OP treats
 * unknown operators, so the body is never silently skipped.
 */
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...)\
  do {                                                                  \
    if ((lhs_target) == 0) {                                            \
      constexpr int LhsTarget = 0;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else if ((lhs_target) == 1) {                                     \
      constexpr int LhsTarget = 1;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else if ((lhs_target) == 2) {                                     \
      constexpr int LhsTarget = 2;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else {                                                            \
      LOG(FATAL) << "Invalid lhs target: " << (lhs_target);             \
    }                                                                   \
  } while (0)

/*!
 * \brief CUDA implementation of g-SDDMM on Csr format.
 *
 * Converts the runtime element width (\p bits), operator name (\p op) and
 * operand targets (\p lhs_target, \p rhs_target) into template arguments,
 * then forwards everything to cuda::SDDMMCsr.
 *
 * \param op Binary operator name (see SWITCH_OP for the accepted names).
 * \param bcast Broadcast offsets, passed unchanged to the kernel.
 * \param csr Sparse structure, passed unchanged to the kernel.
 * \param lhs Left operand array.
 * \param rhs Right operand array.
 * \param out Output array handed to the kernel.
 * \param lhs_target Target selector for \p lhs (0, 1 or 2).
 * \param rhs_target Target selector for \p rhs (0, 1 or 2).
 */
template <int XPU, typename IdType, int bits>
void SDDMMCsr(const std::string& op,
              const BcastOff& bcast,
              const CSRMatrix& csr,
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
            bcast, csr, lhs, rhs, out);
      });
    });
  });
}

/*!
 * \brief CUDA implementation of g-SDDMM on a heterograph using Csr format.
 *
 * Calls cuda::SDDMMCsr once per relation type, sequentially in a host-side
 * loop. For relation \c etype the call receives vec_csr[etype],
 * vec_lhs[lhs_eid[etype]], vec_rhs[rhs_eid[etype]] and vec_out[etype].
 * The number of relations processed is lhs_eid.size(). The other vectors are
 * assumed (not checked here) to be large enough for those indices.
 *
 * \param op Binary operator name (see SWITCH_OP for the accepted names).
 * \param bcast Broadcast offsets, shared by every relation type.
 * \param vec_csr Per-relation sparse structures.
 * \param vec_lhs Left operand arrays, indexed through \p lhs_eid.
 * \param vec_rhs Right operand arrays, indexed through \p rhs_eid.
 * \param vec_out Per-relation output arrays.
 * \param lhs_target Target selector for the lhs operands (0, 1 or 2).
 * \param rhs_target Target selector for the rhs operands (0, 1 or 2).
 * \param lhs_eid For each relation type, the index into \p vec_lhs.
 * \param rhs_eid For each relation type, the index into \p vec_rhs.
 */
template <int XPU, typename IdType, int bits>
void SDDMMCsrHetero(const std::string& op,
              const BcastOff& bcast,
              const std::vector<CSRMatrix>& vec_csr,
              const std::vector<NDArray>& vec_lhs,
              const std::vector<NDArray>& vec_rhs,
              std::vector<NDArray> vec_out,
              int lhs_target,
              int rhs_target,
              const std::vector<dgl_type_t>& lhs_eid,
              const std::vector<dgl_type_t>& rhs_eid) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        /* Call SDDMM CUDA kernel for each relation type sequentially */
        for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
          CSRMatrix csr = vec_csr[etype];
          NDArray lhs = vec_lhs[lhs_eid[etype]];
          NDArray rhs = vec_rhs[rhs_eid[etype]];
          NDArray out = vec_out[etype];
          cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
            bcast, csr, lhs, rhs, out);
        }
      });
    });
  });
}

/*!
 * \brief CUDA implementation of g-SDDMM on Coo format.
 *
 * Converts the runtime element width (\p bits), operator name (\p op) and
 * operand targets (\p lhs_target, \p rhs_target) into template arguments,
 * then forwards everything to cuda::SDDMMCoo.
 *
 * \param op Binary operator name (see SWITCH_OP for the accepted names).
 * \param bcast Broadcast offsets, passed unchanged to the kernel.
 * \param coo Sparse structure, passed unchanged to the kernel.
 * \param lhs Left operand array.
 * \param rhs Right operand array.
 * \param out Output array handed to the kernel.
 * \param lhs_target Target selector for \p lhs (0, 1 or 2).
 * \param rhs_target Target selector for \p rhs (0, 1 or 2).
 */
template <int XPU, typename IdType, int bits>
void SDDMMCoo(const std::string& op,
              const BcastOff& bcast,
              const COOMatrix& coo,
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
            bcast, coo, lhs, rhs, out);
      });
    });
  });
}

/*!
 * \brief CUDA implementation of g-SDDMM on a heterograph using Coo format.
 *
 * Calls cuda::SDDMMCoo once per relation type, sequentially in a host-side
 * loop. For relation \c etype the call receives vec_coo[etype],
 * vec_lhs[lhs_eid[etype]], vec_rhs[rhs_eid[etype]] and vec_out[etype].
 * The number of relations processed is lhs_eid.size(). The other vectors are
 * assumed (not checked here) to be large enough for those indices.
 *
 * \param op Binary operator name (see SWITCH_OP for the accepted names).
 * \param bcast Broadcast offsets, shared by every relation type.
 * \param vec_coo Per-relation sparse structures.
 * \param vec_lhs Left operand arrays, indexed through \p lhs_eid.
 * \param vec_rhs Right operand arrays, indexed through \p rhs_eid.
 * \param vec_out Per-relation output arrays.
 * \param lhs_target Target selector for the lhs operands (0, 1 or 2).
 * \param rhs_target Target selector for the rhs operands (0, 1 or 2).
 * \param lhs_eid For each relation type, the index into \p vec_lhs.
 * \param rhs_eid For each relation type, the index into \p vec_rhs.
 */
template <int XPU, typename IdType, int bits>
void SDDMMCooHetero(const std::string& op,
              const BcastOff& bcast,
              const std::vector<COOMatrix>& vec_coo,
              const std::vector<NDArray>& vec_lhs,
              const std::vector<NDArray>& vec_rhs,
              std::vector<NDArray> vec_out,
              int lhs_target,
              int rhs_target,
              const std::vector<dgl_type_t>& lhs_eid,
              const std::vector<dgl_type_t>& rhs_eid) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        /* Call SDDMM CUDA kernel for each relation type sequentially */
        for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
          COOMatrix coo = vec_coo[etype];
          NDArray lhs = vec_lhs[lhs_eid[etype]];
          NDArray rhs = vec_rhs[rhs_eid[etype]];
          NDArray out = vec_out[etype];
          cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
            bcast, coo, lhs, rhs, out);
        }
      });
    });
  });
}

// Explicit GPU instantiations of SDDMMCsr for int32_t/int64_t index types
// and 16/32/64-bit element widths.
template void SDDMMCsr<kDLGPU, int32_t, 16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int64_t, 16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int32_t, 32>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int64_t, 32>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int32_t, 64>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int64_t, 64>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);

// Explicit GPU instantiations of SDDMMCsrHetero for int32_t/int64_t index
// types and 16/32/64-bit element widths.
template void SDDMMCsrHetero<kDLGPU, int32_t, 16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCsrHetero<kDLGPU, int64_t, 16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCsrHetero<kDLGPU, int32_t, 32>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCsrHetero<kDLGPU, int64_t, 32>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCsrHetero<kDLGPU, int32_t, 64>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCsrHetero<kDLGPU, int64_t, 64>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);

// Explicit GPU instantiations of SDDMMCoo for int32_t/int64_t index types
// and 16/32/64-bit element widths.
template void SDDMMCoo<kDLGPU, int32_t, 16>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int64_t, 16>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int32_t, 32>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int64_t, 32>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int32_t, 64>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int64_t, 64>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);

// Explicit GPU instantiations of SDDMMCooHetero for int32_t/int64_t index
// types and 16/32/64-bit element widths.
template void SDDMMCooHetero<kDLGPU, int32_t, 16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCooHetero<kDLGPU, int64_t, 16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCooHetero<kDLGPU, int32_t, 32>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCooHetero<kDLGPU, int64_t, 32>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCooHetero<kDLGPU, int32_t, 64>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCooHetero<kDLGPU, int64_t, 64>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);

}  // namespace aten
}  // namespace dgl