/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/sddmm.cuh
 * @brief SDDMM CUDA kernel function header.
 */
#ifndef DGL_ARRAY_CUDA_SDDMM_CUH_
#define DGL_ARRAY_CUDA_SDDMM_CUH_

#include <dgl/bcast.h>

#include "../../runtime/cuda/cuda_common.h"
#include "../selector.h"
#include "./functor.cuh"
#include "./utils.h"
#include "atomic.cuh"
#include "bf16.cuh"
#include "fp16.cuh"
#include "functor.cuh"
#include "macro.cuh"

namespace dgl {

using namespace cuda;

namespace aten {
namespace cuda {

#define SWITCH_OP(op, Op, ...)                                        \
  do {                                                                \
    if ((op) == "add") {                                              \
      typedef cuda::binary::Add<DType> Op;                            \
      { __VA_ARGS__ }                                                 \
    } else if ((op) == "sub") {                                       \
      typedef cuda::binary::Sub<DType> Op;                            \
      { __VA_ARGS__ }                                                 \
    } else if ((op) == "mul") {                                       \
      typedef cuda::binary::Mul<DType> Op;                            \
      { __VA_ARGS__ }                                                 \
    } else if ((op) == "div") {                                       \
      typedef cuda::binary::Div<DType> Op;                            \
      { __VA_ARGS__ }                                                 \
    } else if ((op) == "copy_lhs") {                                  \
      typedef cuda::binary::CopyLhs<DType> Op;                        \
      { __VA_ARGS__ }                                                 \
    } else if ((op) == "copy_rhs") {                                  \
      typedef cuda::binary::CopyRhs<DType> Op;                        \
      { __VA_ARGS__ }                                                 \
    } else if ((op) == "dot") {                                       \
      typedef cuda::binary::Dot<DType> Op;                            \
      { __VA_ARGS__ }                                                 \
    } else {                                                          \
      LOG(FATAL) << "Unsupported SpMM/SDDMM binary operator: " << op; \
    }                                                                 \
  } while (0)
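
// Usage sketch (illustrative, not from this file): SWITCH_OP binds the
// compile-time functor matching a runtime op string. It assumes a `DType`
// typedef is in scope, since the body expands to cuda::binary::Op<DType>:
//
//   typedef float DType;
//   SWITCH_OP(op, Op, {
//     // Here `Op` names one of cuda::binary::{Add, Sub, Mul, Div, CopyLhs,
//     // CopyRhs, Dot}<DType>.
//     LaunchMyKernel<Op>(/* ... */);  // hypothetical launcher
//   });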

#define SWITCH_RHS(rhs_target, RhsTarget, ...)             \
  do {                                                     \
    if ((rhs_target) == 0) {                               \
      constexpr int RhsTarget = 0;                         \
      { __VA_ARGS__ }                                      \
    } else if ((rhs_target) == 1) {                        \
      constexpr int RhsTarget = 1;                         \
      { __VA_ARGS__ }                                      \
    } else if ((rhs_target) == 2) {                        \
      constexpr int RhsTarget = 2;                         \
      { __VA_ARGS__ }                                      \
    } else {                                               \
      LOG(INFO) << "Invalid rhs target: " << (rhs_target); \
    }                                                      \
  } while (0)

#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...) \
  do {                                                                   \
    if ((lhs_target) == 0) {                                             \
      constexpr int LhsTarget = 0;                                       \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \
    } else if ((lhs_target) == 1) {                                      \
      constexpr int LhsTarget = 1;                                       \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \
    } else if ((lhs_target) == 2) {                                      \
      constexpr int LhsTarget = 2;                                       \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \
    } else {                                                             \
      LOG(INFO) << "Invalid lhs target: " << (lhs_target);               \
    }                                                                    \
  } while (0)
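
// Target encoding, matching the argument order of
// Selector<Target>::Call(src, eid, dst): 0 selects the source node, 1 the
// edge, and 2 the destination node. For example, LhsTarget = 0 with
// RhsTarget = 2 (the template defaults below) combines source-node features
// with destination-node features, as in DGL's u_dot_v-style ops.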

constexpr unsigned int full_mask = 0xffffffff;

/**
 * @brief CUDA kernel of g-SDDMM on COO format.
 * @note It uses an edge-parallel strategy: different threadblocks (on the
 *       y-axis) are responsible for the computation on different edges, while
 *       threadblocks on the x-axis are responsible for the computation at
 *       different positions in the feature dimension.
 */
template <
    typename Idx, typename DType, typename BinaryOp, bool UseBcast = false,
    bool UseIdx = false, int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCooKernel(
    const DType* __restrict__ lhs, const DType* __restrict__ rhs,
    DType* __restrict__ out, const Idx* __restrict__ row,
    const Idx* __restrict__ col, const Idx* __restrict__ edge_map, int64_t N,
    int64_t M, int64_t E, int64_t reduce_size,
    const int64_t* __restrict__ lhs_off, const int64_t* __restrict__ rhs_off,
    int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
  // SDDMM with COO.
  Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
  const Idx stride_y = blockDim.y * gridDim.y;
  while (ty < E) {
    const Idx src = _ldg(row + ty);
    const Idx dst = _ldg(col + ty);
    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
    const DType* lhsoff =
        BinaryOp::use_lhs
            ? (lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len)
            : nullptr;
    const DType* rhsoff =
        BinaryOp::use_rhs
            ? (rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len)
            : nullptr;
    DType* outoff = out + eid * out_len;
    int tx = blockIdx.x * blockDim.x + threadIdx.x;
    const int stride_x = blockDim.x * gridDim.x;
    while (tx < out_len) {
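      // Under broadcasting, lhs_off/rhs_off map an output feature position
      // to the corresponding position in the operand's (possibly smaller)
      // feature shape; otherwise the mapping is the identity.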
      const Idx lhs_add = UseBcast ? lhs_off[tx] : tx;
      const Idx rhs_add = UseBcast ? rhs_off[tx] : tx;
      DType val = BinaryOp::Call(
          lhsoff + lhs_add * reduce_size, rhsoff + rhs_add * reduce_size,
          reduce_size);
      outoff[tx] = val;
      tx += stride_x;
    }
    ty += stride_y;
  }
}

/**
 * @brief CUDA kernel of SDDMM-dot on COO format, accelerated with tree
 *        reduction.
 * @note It uses an edge-parallel strategy: different threadblocks (on the
 *       x-axis) are responsible for the computation on different edges, while
 *       threadblocks on the y-axis are responsible for the computation at
 *       different positions in the feature dimension.
 */
template <
    typename Idx, typename DType, bool UseBcast = false, bool UseIdx = false,
    int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCooTreeReduceKernel(
    const DType* __restrict__ lhs, const DType* __restrict__ rhs,
    DType* __restrict__ out, const Idx* __restrict__ row,
    const Idx* __restrict__ col, const Idx* __restrict__ edge_map, int64_t N,
    int64_t M, int64_t E, int64_t reduce_size,
    const int64_t* __restrict__ lhs_off, const int64_t* __restrict__ rhs_off,
    int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
  Idx ty = blockIdx.x * blockDim.y + threadIdx.y;
  if (ty < E) {
    const Idx src = _ldg(row + ty);
    const Idx dst = _ldg(col + ty);
    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
    const DType* lhsoff =
        lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len;
    const DType* rhsoff =
        rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len;
    DType* outoff = out + eid * out_len;
    int tx = threadIdx.x;  // tx < 32
    for (int i = blockIdx.y; i < out_len;
         i += gridDim.y) {  // over output feature dimension
      const Idx lhs_add = UseBcast ? __ldg(lhs_off + i) : i;
      const Idx rhs_add = UseBcast ? __ldg(rhs_off + i) : i;
      DType val = reduce::Sum<Idx, DType>::zero();
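      // Each of the 32 lanes accumulates a strided slice of the dot product;
      // the stride of 64 plus the explicit second term below unrolls the
      // loop by a factor of two (elements j and j + 32 per iteration).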
      for (int j = tx; j < reduce_size; j += 64) {
        val += lhsoff[lhs_add * reduce_size + j] *
               rhsoff[rhs_add * reduce_size + j];
        if (j + 32 < reduce_size)
          val += lhsoff[lhs_add * reduce_size + j + 32] *
                 rhsoff[rhs_add * reduce_size + j + 32];
      }
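      // Warp-level tree reduction: each __shfl_down_sync halves the number
      // of lanes holding live partial sums, so after five steps lane 0 holds
      // the full sum for this output position.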
#pragma unroll
      for (int offset = 16; offset > 0; offset /= 2)
        val += __shfl_down_sync(full_mask, val, offset);
      if (tx == 0) outoff[i] = val;
    }
  }
}

// Binary search the row_offsets to find the source node of the edge id.
template <typename Idx>
__device__ __forceinline__ Idx
BinarySearchSrc(const Idx* array, Idx length, Idx eid) {
  Idx lo = 0, hi = length - 1;
  while (lo < hi) {
    Idx mid = (lo + hi) >> 1;
    if (_ldg(array + mid) <= eid) {
      lo = mid + 1;
    } else {
      hi = mid;
    }
  }
  // INVARIANT: lo == hi
  if (_ldg(array + hi) == eid) {
    return hi;
  } else {
    return hi - 1;
  }
}
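
// Worked example (illustrative): with indptr = {0, 2, 5, 7} and eid = 3, the
// search converges to lo == hi == 2, where indptr[2] = 5 != 3, so it returns
// hi - 1 = 1: edge 3 belongs to row 1, which owns edges [2, 5).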

/**
 * @brief CUDA kernel of g-SDDMM on CSR format.
 * @note It uses an edge-parallel strategy: different threadblocks (on the
 *       y-axis) are responsible for the computation on different edges, while
 *       threadblocks on the x-axis are responsible for the computation at
 *       different positions in the feature dimension. To efficiently find the
 *       source node of a given edge in CSR format, it uses binary search over
 *       indptr (time complexity O(log N)); the destination node is read
 *       directly from indices.
 */
template <
    typename Idx, typename DType, typename BinaryOp, bool UseBcast = false,
    bool UseIdx = false, int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCsrKernel(
    const DType* __restrict__ lhs, const DType* __restrict__ rhs,
    DType* __restrict__ out, const Idx* __restrict__ indptr,
    const Idx* __restrict__ indices, const Idx* __restrict__ edge_map,
    int64_t N, int64_t M, int64_t E, int64_t reduce_size,
    const int64_t* __restrict__ lhs_off, const int64_t* __restrict__ rhs_off,
    int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
  // SDDMM with Csr.
  Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
  const Idx stride_y = blockDim.y * gridDim.y;
  while (ty < E) {
    const Idx src = BinarySearchSrc<Idx>(indptr, N + 1, ty);
    const Idx dst = _ldg(indices + ty);
    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
    int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
    const int64_t stride_x = blockDim.x * gridDim.x;
    const DType* lhsoff =
        BinaryOp::use_lhs
            ? (lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len)
            : nullptr;
    const DType* rhsoff =
        BinaryOp::use_rhs
            ? (rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len)
            : nullptr;
    DType* outoff = out + eid * out_len;
    while (tx < out_len) {
      const Idx lhs_add = UseBcast ? lhs_off[tx] : tx;
      const Idx rhs_add = UseBcast ? rhs_off[tx] : tx;
      DType val = BinaryOp::Call(
          lhsoff + lhs_add * reduce_size, rhsoff + rhs_add * reduce_size,
          reduce_size);
      outoff[tx] = val;
      tx += stride_x;
    }
    ty += stride_y;
  }
}

/**
 * @brief CUDA implementation of g-SDDMM on COO format.
 * @param bcast Broadcast information.
 * @param coo The COO matrix.
 * @param lhs The left hand side operand feature.
 * @param rhs The right hand side operand feature.
 * @param out The result feature on edges.
 */
template <
    typename Idx, typename DType, typename Op, int LhsTarget = 0,
    int RhsTarget = 2>
void SDDMMCoo(
    const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs,
    NDArray out) {
  const Idx* row = coo.row.Ptr<Idx>();
  const Idx* col = coo.col.Ptr<Idx>();
  const Idx* edge_map = coo.data.Ptr<Idx>();
  const DType* lhs_data = lhs.Ptr<DType>();
  const DType* rhs_data = rhs.Ptr<DType>();
  DType* out_data = out.Ptr<DType>();
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  int64_t *lhs_off = nullptr, *rhs_off = nullptr;
  int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
  int64_t reduce_dim = bcast.reduce_size;

  const int64_t nnz = coo.row->shape[0];
  const bool use_idx = !IsNullArray(coo.data);

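  // When the operator is dot and the reduce dimension spans at least a full
  // warp, use the tree-reduction kernel; otherwise fall back to the generic
  // element-wise kernel below.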
  if (std::is_same<Op, binary::Dot<DType> >::value && reduce_dim >= 32) {
    const int ntx = 32;  // on feature dimension
    const int nty = 8;   // on out dimension
    const int nbx = (nnz + nty - 1) / nty;
    const int nby = FindNumBlocks<'y'>(len);
    const dim3 nblks(nbx, nby);
    const dim3 nthrs(ntx, nty);
    BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
      CUDA_KERNEL_CALL(
          (SDDMMCooTreeReduceKernel<
              Idx, DType, UseBcast, UseIdx, LhsTarget, RhsTarget>),
          nblks, nthrs, 0, stream, lhs_data, rhs_data, out_data, row, col,
          edge_map, coo.num_rows, coo.num_cols, nnz, reduce_dim, lhs_off,
          rhs_off, lhs_len, rhs_len, len);
    });
  } else {
    const int ntx = FindNumThreads(len);
    const int nty = CUDA_MAX_NUM_THREADS / ntx;
    const int nbx = (len + ntx - 1) / ntx;
    const int nby = FindNumBlocks<'y'>((nnz + nty - 1) / nty);
    const dim3 nblks(nbx, nby);
    const dim3 nthrs(ntx, nty);
    BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
      CUDA_KERNEL_CALL(
          (SDDMMCooKernel<
              Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
          nblks, nthrs, 0, stream, lhs_data, rhs_data, out_data, row, col,
          edge_map, coo.num_rows, coo.num_cols, nnz, reduce_dim, lhs_off,
          rhs_off, lhs_len, rhs_len, len);
    });
  }
}
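
// Dispatch sketch (hypothetical caller; in DGL the actual entry points live
// in the corresponding .cu files). Assuming a `DType` typedef in scope:
//
//   typedef float DType;
//   SWITCH_OP(op, Op, {
//     SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
//       SDDMMCoo<int32_t, DType, Op, LhsTarget, RhsTarget>(
//           bcast, coo, lhs, rhs, out);
//     });
//   });
//
// SDDMMCsr below is dispatched the same way.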

/**
 * @brief CUDA implementation of g-SDDMM on CSR format.
 * @param bcast Broadcast information.
 * @param csr The CSR matrix.
 * @param lhs The left hand side operand feature.
 * @param rhs The right hand side operand feature.
 * @param out The result feature on edges.
 */
template <
    typename Idx, typename DType, typename Op, int LhsTarget = 0,
    int RhsTarget = 2>
void SDDMMCsr(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs,
    NDArray out) {
  const Idx* indptr = csr.indptr.Ptr<Idx>();
  const Idx* indices = csr.indices.Ptr<Idx>();
  const Idx* edge_map = csr.data.Ptr<Idx>();
  const DType* lhs_data = lhs.Ptr<DType>();
  const DType* rhs_data = rhs.Ptr<DType>();
  DType* out_data = out.Ptr<DType>();
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  int64_t N = csr.num_rows, M = csr.num_cols, E = csr.indices->shape[0];

  int64_t *lhs_off = nullptr, *rhs_off = nullptr;
  int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
  int64_t reduce_dim = bcast.reduce_size;

  const int ntx = FindNumThreads(len);
  const int nty = CUDA_MAX_NUM_THREADS / ntx;
  const int nbx = (len + ntx - 1) / ntx;
  const int nby = FindNumBlocks<'y'>((E + nty - 1) / nty);
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const bool use_idx = !IsNullArray(csr.data);

  BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
    CUDA_KERNEL_CALL(
        (SDDMMCsrKernel<
            Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
        nblks, nthrs, 0, stream, lhs_data, rhs_data, out_data, indptr, indices,
        edge_map, N, M, E, reduce_dim, lhs_off, rhs_off, lhs_len, rhs_len, len);
  });
}

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_SDDMM_CUH_