// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/sddmm.cuh
 * @brief SDDMM CUDA kernel function header.
 */
#ifndef DGL_ARRAY_CUDA_SDDMM_CUH_
#define DGL_ARRAY_CUDA_SDDMM_CUH_

#include <dgl/bcast.h>

#include "../../runtime/cuda/cuda_common.h"
#include "../selector.h"
#include "functor.cuh"
#include "utils.h"
#include "atomic.cuh"
#include "bf16.cuh"
#include "fp16.cuh"
#include "macro.cuh"

namespace dgl {

using namespace cuda;

namespace aten {
namespace cuda {

#define SWITCH_OP(op, Op, ...)                                        \
  do {                                                                \
    if ((op) == "add") {                                              \
      typedef cuda::binary::Add<DType> Op;                            \
      { __VA_ARGS__ }                                                 \
    } else if ((op) == "sub") {                                       \
      typedef cuda::binary::Sub<DType> Op;                            \
      { __VA_ARGS__ }                                                 \
    } else if ((op) == "mul") {                                       \
      typedef cuda::binary::Mul<DType> Op;                            \
      { __VA_ARGS__ }                                                 \
    } else if ((op) == "div") {                                       \
      typedef cuda::binary::Div<DType> Op;                            \
      { __VA_ARGS__ }                                                 \
    } else if ((op) == "copy_lhs") {                                  \
      typedef cuda::binary::CopyLhs<DType> Op;                        \
      { __VA_ARGS__ }                                                 \
    } else if ((op) == "copy_rhs") {                                  \
      typedef cuda::binary::CopyRhs<DType> Op;                        \
      { __VA_ARGS__ }                                                 \
    } else if ((op) == "dot") {                                       \
      typedef cuda::binary::Dot<DType> Op;                            \
      { __VA_ARGS__ }                                                 \
    } else {                                                          \
      LOG(FATAL) << "Unsupported SpMM/SDDMM binary operator: " << op; \
    }                                                                 \
  } while (0)
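// A minimal usage sketch for SWITCH_OP (hypothetical call site; assumes a
// runtime op string `op` and a concrete DType typedef are in scope, as in the
// host-side SDDMM dispatch code):
//
//   SWITCH_OP(op, Op, {
//     // Op is now a concrete functor type such as cuda::binary::Add<DType>,
//     // usable as a template argument of the kernels below.
//     SDDMMCsr<int32_t, DType, Op>(bcast, csr, lhs, rhs, out);
//   });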

#define SWITCH_RHS(rhs_target, RhsTarget, ...)             \
  do {                                                     \
    if ((rhs_target) == 0) {                               \
      constexpr int RhsTarget = 0;                         \
      { __VA_ARGS__ }                                      \
    } else if ((rhs_target) == 1) {                        \
      constexpr int RhsTarget = 1;                         \
      { __VA_ARGS__ }                                      \
    } else if ((rhs_target) == 2) {                        \
      constexpr int RhsTarget = 2;                         \
      { __VA_ARGS__ }                                      \
    } else {                                               \
      LOG(INFO) << "Invalid rhs target: " << (rhs_target); \
    }                                                      \
  } while (0)

#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...) \
  do {                                                                   \
    if ((lhs_target) == 0) {                                             \
      constexpr int LhsTarget = 0;                                       \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \
    } else if ((lhs_target) == 1) {                                      \
      constexpr int LhsTarget = 1;                                       \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \
    } else if ((lhs_target) == 2) {                                      \
      constexpr int LhsTarget = 2;                                       \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \
    } else {                                                             \
      LOG(INFO) << "Invalid lhs target: " << (lhs_target);               \
    }                                                                    \
  } while (0)
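// A minimal usage sketch for SWITCH_TARGET (hypothetical call site). The
// integer targets encode where a feature lives, matching
// Selector<I>::Call(src, eid, dst) used below: 0 selects the source node,
// 1 the edge, and 2 the destination node.
//
//   SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
//     SDDMMCoo<int32_t, DType, Op, LhsTarget, RhsTarget>(
//         bcast, coo, lhs, rhs, out);
//   });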

constexpr unsigned int full_mask = 0xffffffff;

/**
 * @brief CUDA kernel of g-SDDMM on Coo format.
 * @note It uses an edge-parallel strategy: different threadblocks (on the
 *       y-axis) are responsible for the computation on different edges, while
 *       threadblocks on the x-axis are responsible for the computation on
 *       different positions in the feature dimension.
 */
template <
    typename Idx, typename DType, typename BinaryOp, bool UseBcast = false,
    bool UseIdx = false, int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCooKernel(
    const DType* __restrict__ lhs, const DType* __restrict__ rhs,
    DType* __restrict__ out, const Idx* __restrict__ row,
    const Idx* __restrict__ col, const Idx* __restrict__ edge_map, int64_t N,
    int64_t M, int64_t E, int64_t reduce_size,
    const int64_t* __restrict__ lhs_off, const int64_t* __restrict__ rhs_off,
    int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
  // SDDMM with COO.
  Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
  const Idx stride_y = blockDim.y * gridDim.y;
  while (ty < E) {
    const Idx src = _ldg(row + ty);
    const Idx dst = _ldg(col + ty);
    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
    const DType* lhsoff =
        BinaryOp::use_lhs
            ? (lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len)
            : nullptr;
    const DType* rhsoff =
        BinaryOp::use_rhs
            ? (rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len)
            : nullptr;
    DType* outoff = out + eid * out_len;
    int tx = blockIdx.x * blockDim.x + threadIdx.x;
    const int stride_x = blockDim.x * gridDim.x;
    while (tx < out_len) {
      const Idx lhs_add = UseBcast ? lhs_off[tx] : tx;
      const Idx rhs_add = UseBcast ? rhs_off[tx] : tx;
      DType val = BinaryOp::Call(
          lhsoff + lhs_add * reduce_size, rhsoff + rhs_add * reduce_size,
          reduce_size);
      outoff[tx] = val;
      tx += stride_x;
    }
    ty += stride_y;
  }
}
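// Thread-mapping sketch for SDDMMCooKernel (illustrative numbers, not a
// prescribed launch): with blockDim = (4, 2), gridDim = (1, 2), E = 10 and
// out_len = 4, the thread at (tx = 1, ty = 3) writes outoff[1] of edge 3,
// i.e. out[eid(3) * out_len + 1], then grid-strides to edge ty = 3 + 4 = 7.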

/**
 * @brief CUDA kernel of SDDMM-dot on Coo format, accelerated with tree
 *        reduction.
 * @note It uses an edge-parallel strategy: different threadblocks (on the
 *       y-axis) are responsible for the computation on different edges, while
 *       threadblocks on the x-axis are responsible for the computation on
 *       different positions in the feature dimension.
 */
template <
    typename Idx, typename DType, bool UseBcast = false, bool UseIdx = false,
    int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCooTreeReduceKernel(
    const DType* __restrict__ lhs, const DType* __restrict__ rhs,
    DType* __restrict__ out, const Idx* __restrict__ row,
    const Idx* __restrict__ col, const Idx* __restrict__ edge_map, int64_t N,
    int64_t M, int64_t E, int64_t reduce_size,
    const int64_t* __restrict__ lhs_off, const int64_t* __restrict__ rhs_off,
    int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
  Idx ty = blockIdx.x * blockDim.y + threadIdx.y;
  if (ty < E) {
    const Idx src = _ldg(row + ty);
    const Idx dst = _ldg(col + ty);
    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
    const DType* lhsoff =
        lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len;
    const DType* rhsoff =
        rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len;
    DType* outoff = out + eid * out_len;
    int tx = threadIdx.x;  // tx < 32
    for (int i = blockIdx.y; i < out_len;
         i += gridDim.y) {  // over output feature dimension
      const Idx lhs_add = UseBcast ? __ldg(lhs_off + i) : i;
      const Idx rhs_add = UseBcast ? __ldg(rhs_off + i) : i;
      DType val = reduce::Sum<Idx, DType>::zero();
      // Each of the 32 x-threads accumulates a strided slice of the dot
      // product, handling two elements (j and j + 32) per iteration.
      for (int j = tx; j < reduce_size; j += 64) {
        val += lhsoff[lhs_add * reduce_size + j] *
               rhsoff[rhs_add * reduce_size + j];
        if (j + 32 < reduce_size)
          val += lhsoff[lhs_add * reduce_size + j + 32] *
                 rhsoff[rhs_add * reduce_size + j + 32];
      }
      // Tree-reduce the 32 partial sums with shuffles; lane 0 ends up with
      // the full dot product and writes it out.
#pragma unroll
      for (int offset = 16; offset > 0; offset /= 2)
        val += __shfl_down(val, offset);
      if (tx == 0) outoff[i] = val;
    }
  }
}
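// Worked trace of the shuffle reduction above, assuming the 32 x-lanes hold
// partial sums p0..p31: after offset 16, lane i < 16 holds p_i + p_(i+16);
// after offset 8, lane i < 8 holds a sum of four partials; after offsets 4,
// 2, and 1, lane 0 holds p0 + ... + p31, the complete dot product for
// feature position i.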

// Binary search the row_offsets to find the source node of the edge id.
template <typename Idx>
__device__ __forceinline__ Idx
BinarySearchSrc(const Idx* array, Idx length, Idx eid) {
  Idx lo = 0, hi = length - 1;
  while (lo < hi) {
    Idx mid = (lo + hi) >> 1;
    if (_ldg(array + mid) <= eid) {
      lo = mid + 1;
    } else {
      hi = mid;
    }
  }
  // INVARIANT: lo == hi
  if (_ldg(array + hi) == eid) {
    return hi;
  } else {
    return hi - 1;
  }
}
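// Worked example: for indptr = {0, 2, 2, 5} (3 rows, 5 edges),
// BinarySearchSrc(indptr, 4, 3) returns 2: the search converges to
// lo == hi == 3, and since indptr[3] = 5 != 3 it returns 3 - 1 = 2, i.e.
// row 2, which owns edges 2..4 (the empty row 1 is skipped correctly).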

/**
 * @brief CUDA kernel of g-SDDMM on Csr format.
 * @note It uses an edge-parallel strategy: different threadblocks (on the
 *       y-axis) are responsible for the computation on different edges, while
 *       threadblocks on the x-axis are responsible for the computation on
 *       different positions in the feature dimension. To efficiently find the
 *       source node index and destination node index of a given edge on Csr
 *       format, it uses binary search (time complexity O(log N)).
 */
template <
    typename Idx, typename DType, typename BinaryOp, bool UseBcast = false,
    bool UseIdx = false, int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCsrKernel(
    const DType* __restrict__ lhs, const DType* __restrict__ rhs,
    DType* __restrict__ out, const Idx* __restrict__ indptr,
    const Idx* __restrict__ indices, const Idx* __restrict__ edge_map,
    int64_t N, int64_t M, int64_t E, int64_t reduce_size,
    const int64_t* __restrict__ lhs_off, const int64_t* __restrict__ rhs_off,
    int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
  // SDDMM with Csr.
  Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
  const Idx stride_y = blockDim.y * gridDim.y;
  while (ty < E) {
    const Idx src = BinarySearchSrc<Idx>(indptr, N + 1, ty);
    const Idx dst = _ldg(indices + ty);
    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
    int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
    const int64_t stride_x = blockDim.x * gridDim.x;
    const DType* lhsoff =
        BinaryOp::use_lhs
            ? (lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len)
            : nullptr;
    const DType* rhsoff =
        BinaryOp::use_rhs
            ? (rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len)
            : nullptr;
    DType* outoff = out + eid * out_len;
    while (tx < out_len) {
      const Idx lhs_add = UseBcast ? lhs_off[tx] : tx;
      const Idx rhs_add = UseBcast ? rhs_off[tx] : tx;
      DType val = BinaryOp::Call(
          lhsoff + lhs_add * reduce_size, rhsoff + rhs_add * reduce_size,
          reduce_size);
      outoff[tx] = val;
      tx += stride_x;
    }
    ty += stride_y;
  }
}

/**
 * @brief CUDA implementation of g-SDDMM on Coo format.
 * @param bcast Broadcast information.
 * @param coo The Coo matrix.
 * @param lhs The left hand side operand feature.
 * @param rhs The right hand side operand feature.
 * @param out The result feature on edges.
 */
template <
    typename Idx, typename DType, typename Op, int LhsTarget = 0,
    int RhsTarget = 2>
void SDDMMCoo(
    const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs,
    NDArray out) {
  const Idx* row = coo.row.Ptr<Idx>();
  const Idx* col = coo.col.Ptr<Idx>();
  const Idx* edge_map = coo.data.Ptr<Idx>();
  const DType* lhs_data = lhs.Ptr<DType>();
  const DType* rhs_data = rhs.Ptr<DType>();
  DType* out_data = out.Ptr<DType>();
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();

  int64_t *lhs_off = nullptr, *rhs_off = nullptr;
  int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
  int64_t reduce_dim = bcast.reduce_size;

  const int64_t nnz = coo.row->shape[0];
  const bool use_idx = !IsNullArray(coo.data);

  if (std::is_same<Op, binary::Dot<DType> >::value && reduce_dim >= 32) {
    const int ntx = 32;  // on feature dimension
    const int nty = 8;   // on out dimension
    const int nbx = (nnz + nty - 1) / nty;
    const int nby = FindNumBlocks<'y'>(len);
    const dim3 nblks(nbx, nby);
    const dim3 nthrs(ntx, nty);
    BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
      CUDA_KERNEL_CALL(
          (SDDMMCooTreeReduceKernel<
              Idx, DType, UseBcast, UseIdx, LhsTarget, RhsTarget>),
          nblks, nthrs, 0, stream, lhs_data, rhs_data, out_data, row, col,
          edge_map, coo.num_rows, coo.num_cols, nnz, reduce_dim, lhs_off,
          rhs_off, lhs_len, rhs_len, len);
    });
  } else {
    const int ntx = FindNumThreads(len);
    const int nty = CUDA_MAX_NUM_THREADS / ntx;
    const int nbx = (len + ntx - 1) / ntx;
    const int nby = FindNumBlocks<'y'>((nnz + nty - 1) / nty);
    const dim3 nblks(nbx, nby);
    const dim3 nthrs(ntx, nty);
    BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
      CUDA_KERNEL_CALL(
          (SDDMMCooKernel<
              Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
          nblks, nthrs, 0, stream, lhs_data, rhs_data, out_data, row, col,
          edge_map, coo.num_rows, coo.num_cols, nnz, reduce_dim, lhs_off,
          rhs_off, lhs_len, rhs_len, len);
    });
  }
}
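// A minimal instantiation sketch (hypothetical caller; assumes a BcastOff
// `bcast`, e.g. from CalcBcastOff, a COOMatrix `coo`, and float NDArrays
// `lhs`, `rhs`, `out` of matching shapes). The defaults LhsTarget = 0 and
// RhsTarget = 2 read lhs from source nodes and rhs from destination nodes:
//
//   SDDMMCoo<int32_t, float, cuda::binary::Dot<float> >(
//       bcast, coo, lhs, rhs, out);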

/**
 * @brief CUDA implementation of g-SDDMM on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param lhs The left hand side operand feature.
 * @param rhs The right hand side operand feature.
 * @param out The result feature on edges.
 */
template <
    typename Idx, typename DType, typename Op, int LhsTarget = 0,
    int RhsTarget = 2>
void SDDMMCsr(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs,
    NDArray out) {
  const Idx* indptr = csr.indptr.Ptr<Idx>();
  const Idx* indices = csr.indices.Ptr<Idx>();
  const Idx* edge_map = csr.data.Ptr<Idx>();
  const DType* lhs_data = lhs.Ptr<DType>();
  const DType* rhs_data = rhs.Ptr<DType>();
  DType* out_data = out.Ptr<DType>();
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  int64_t N = csr.num_rows, M = csr.num_cols, E = csr.indices->shape[0];

  int64_t *lhs_off = nullptr, *rhs_off = nullptr;
  int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
  int64_t reduce_dim = bcast.reduce_size;

  const int ntx = FindNumThreads(len);
  const int nty = CUDA_MAX_NUM_THREADS / ntx;
  const int nbx = (len + ntx - 1) / ntx;
  const int nby = FindNumBlocks<'y'>((E + nty - 1) / nty);
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const bool use_idx = !IsNullArray(csr.data);

  BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
    CUDA_KERNEL_CALL(
        (SDDMMCsrKernel<
            Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
        nblks, nthrs, 0, stream, lhs_data, rhs_data, out_data, indptr, indices,
        edge_map, N, M, E, reduce_dim, lhs_off, rhs_off, lhs_len, rhs_len, len);
  });
}
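// Launch-geometry sketch for SDDMMCsr (illustrative; assumes the usual
// CUDA_MAX_NUM_THREADS of 1024): for len = 64, FindNumThreads picks
// ntx = 64, so each block also covers nty = 1024 / 64 = 16 edges; nbx = 1
// block spans the feature dimension, while nby blocks grid-stride over the
// (E + 15) / 16 edge chunks on the y-axis.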

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_SDDMM_CUH_