// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/spmm.cuh
 * @brief SPMM CUDA kernel function header.
 */
#ifndef DGL_ARRAY_CUDA_SPMM_CUH_
#define DGL_ARRAY_CUDA_SPMM_CUH_

#include <dgl/bcast.h>

#include <limits>

#include "../../runtime/cuda/cuda_common.h"
#include "utils.h"
#include "atomic.cuh"
#include "bf16.cuh"
#include "fp16.cuh"
#include "macro.cuh"

namespace dgl {

using namespace cuda;

namespace aten {

/**
 * @brief Determine whether cusparse SpMM function is applicable.
 */
template <typename DType, typename IdType>
inline bool cusparse_available(bool more_nnz_than_matrix_size) {
#if DTKRT_VERSION < 11000
  if (std::is_same<IdType, int>::value &&
      (std::is_same<DType, float>::value || std::is_same<DType, double>::value))
    return true;
  return false;
#else
  if (std::is_same<DType, __half>::value ||
      std::is_same<DType, __hip_bfloat16>::value)
    return false;  // cusparse's SpMM on fp16 is slow, temporarily disabled.
  // If the CSR matrix has more NNZ than matrix size, we should not use
  // cuSPARSE 11.1.
  return !more_nnz_than_matrix_size;
#endif
}
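
// Illustrative usage sketch (added comment, not part of the original code):
// a dispatch site would typically consult cusparse_available() before
// choosing between the cuSPARSE path and the custom kernels below. The
// `more_nnz` expression is only an assumption about how a caller might derive
// the flag.
//
//   const bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
//   if (cusparse_available<float, int32_t>(more_nnz)) {
//     // take the CusparseCsrmm2 path defined later in this header.
//   } else {
//     // fall back to SpMMCsr with the desired binary/reduce operators.
//   }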

namespace {

/** @brief Call cuBLAS geam API for transpose operation for float and double. */
template <typename DType>
hipblasStatus_t Xgeam(
    hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb,
    int m, int n, const DType* alpha, const DType* A, int lda,
    const DType* beta, const DType* B, int ldb, DType* C, int ldc) {
  LOG(FATAL) << "Not supported dtype";
  return HIPBLAS_STATUS_EXECUTION_FAILED;
}

template <>
hipblasStatus_t Xgeam<__half>(
    hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb,
    int m, int n, const __half* alpha, const __half* A, int lda,
    const __half* beta, const __half* B, int ldb, __half* C, int ldc) {
  // TODO(ndickson): There is no cublasHgeam, so a different
  // implementation would be required.
  LOG(FATAL) << "Xgeam does not support dtype half (FP16)";
  return HIPBLAS_STATUS_EXECUTION_FAILED;
}

#if BF16_ENABLED
template <>
hipblasStatus_t Xgeam<__hip_bfloat16>(
    hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb,
    int m, int n, const __hip_bfloat16* alpha, const __hip_bfloat16* A, int lda,
    const __hip_bfloat16* beta, const __hip_bfloat16* B, int ldb,
    __hip_bfloat16* C, int ldc) {
  // TODO(ndickson): There is no cublasHgeam, so a different
  // implementation would be required.
  LOG(FATAL) << "Xgeam does not support dtype bfloat16 (BF16)";
  return HIPBLAS_STATUS_EXECUTION_FAILED;
}
#endif  // BF16_ENABLED

template <>
hipblasStatus_t Xgeam<float>(
    hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb,
    int m, int n, const float* alpha, const float* A, int lda,
    const float* beta, const float* B, int ldb, float* C, int ldc) {
  return hipblasSgeam(
      handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
}

template <>
hipblasStatus_t Xgeam<double>(
    hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb,
    int m, int n, const double* alpha, const double* A, int lda,
    const double* beta, const double* B, int ldb, double* C, int ldc) {
  return hipblasDgeam(
      handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
}

/**
 * @brief Transpose operator kernel implementation.
 * @note not efficient but it's not a bottleneck, used for float16 dtype.
 */
template <typename DType>
__global__ void _TransposeKernel(
    const DType* __restrict__ in, DType* __restrict__ out, int n, int m) {
  int i = blockIdx.x;
  for (int j = threadIdx.x; j < m; j += blockDim.x)
    out[i * m + j] = in[j * n + i];
}

/**
 * @brief Transpose the input matrix.
 * @param row number of rows of input matrix.
 * @param col number of columns of input matrix.
 */
template <typename DType>
void _Transpose(const DType* in, DType* out, int row, int col) {
  DType alpha = 1., beta = 0.;
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  if (!thr_entry->cublas_handle)
    CUBLAS_CALL(hipblasCreate(&(thr_entry->cublas_handle)));
  CUBLAS_CALL(hipblasSetStream(thr_entry->cublas_handle, stream));
  CUBLAS_CALL(Xgeam<DType>(
      thr_entry->cublas_handle, HIPBLAS_OP_T, HIPBLAS_OP_N, row, col, &alpha, in,
      col, &beta, nullptr, row, out, row));
}
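
// Illustrative example (added comment, not part of the original code): given a
// row-major 2 x 3 device buffer, _Transpose writes its row-major 3 x 2
// transpose into another device buffer.
//
//   // in  = {1, 2, 3,
//   //        4, 5, 6};              // viewed as (row = 2, col = 3)
//   _Transpose(in, out, /*row=*/2, /*col=*/3);
//   // out = {1, 4,
//   //        2, 5,
//   //        3, 6};                 // viewed as (3, 2)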

/**
 * @brief Transpose the input matrix for data type half.
 * @note cuBLAS has no geam API for half data type, fall back to our kernel.
 */
template <>
void _Transpose<__half>(const __half* in, __half* out, int row, int col) {
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  int nt = FindNumThreads(row);
  int nb = col;
  CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, stream, in, out, col, row);
}

#if BF16_ENABLED
/**
 * @brief Transpose the input matrix for data type bfloat16.
 * @note cuBLAS has no geam API for bf16 data type, fall back to our kernel.
 */
template <>
void _Transpose<__hip_bfloat16>(
    const __hip_bfloat16* in, __hip_bfloat16* out, int row, int col) {
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  int nt = FindNumThreads(row);
  int nb = col;
  CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, stream, in, out, col, row);
}
#endif  // BF16_ENABLED

#if DTKRT_VERSION < 11000
template <typename DType>
hipsparseStatus_t Xcsrmm2(
    hipsparseHandle_t handle, hipsparseOperation_t transA,
    hipsparseOperation_t transB, int m, int n, int k, int nnz,
    const DType* alpha, const hipsparseMatDescr_t descrA, const DType* csrValA,
    const int* csrRowPtrA, const int* csrColIndA, const DType* B, int ldb,
    const DType* beta, DType* C, int ldc) {
  LOG(INFO) << "Not supported dtype";
  return HIPSPARSE_STATUS_EXECUTION_FAILED;
}

template <>
hipsparseStatus_t Xcsrmm2<float>(
    hipsparseHandle_t handle, hipsparseOperation_t transA,
    hipsparseOperation_t transB, int m, int n, int k, int nnz,
    const float* alpha, const hipsparseMatDescr_t descrA, const float* csrValA,
    const int* csrRowPtrA, const int* csrColIndA, const float* B, int ldb,
    const float* beta, float* C, int ldc) {
  return hipsparseScsrmm2(
      handle, transA, transB, m, n, k, nnz, alpha, descrA, csrValA, csrRowPtrA,
      csrColIndA, B, ldb, beta, C, ldc);
}

template <>
hipsparseStatus_t Xcsrmm2<double>(
    hipsparseHandle_t handle, hipsparseOperation_t transA,
    hipsparseOperation_t transB, int m, int n, int k, int nnz,
    const double* alpha, const hipsparseMatDescr_t descrA, const double* csrValA,
    const int* csrRowPtrA, const int* csrColIndA, const double* B, int ldb,
    const double* beta, double* C, int ldc) {
  return hipsparseDcsrmm2(
      handle, transA, transB, m, n, k, nnz, alpha, descrA, csrValA, csrRowPtrA,
      csrColIndA, B, ldb, beta, C, ldc);
}
#endif

/** Cusparse implementation of SpMM on Csr format. */
template <typename DType, typename IdType>
void CusparseCsrmm2(
    const DGLContext& ctx, const CSRMatrix& csr, const DType* B_data,
    const DType* A_data, DType* C_data, int x_length,
    bool use_deterministic_alg_only = false) {
  // We use csrmm2 to perform the following operation:
  // C = A x B, where A is a sparse matrix in csr format, B is the dense matrix
  // for node feature tensor. However, since cusparse only supports
  // column-major, while our tensor is stored in row-major, the actual
  // computation is: C = trans(A x trans(B)). Currently, we use cublasXgeam to
  // implement transposition and allocate intermediate workspace memory for
  // this.
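  // Worked shape note (added for clarity): a row-major B of shape (k x n)
  // occupies the same memory as a column-major trans(B) of shape (n x k).
  // The legacy (#else) path therefore asks csrmm2 for A x trans(trans(B)) in
  // column-major order, which lands in memory as trans(A x B); one extra
  // transpose (cublasXgeam / _Transpose) then recovers the row-major C. The
  // generic-API path below avoids this by declaring HIPSPARSE_ORDER_ROW.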
  const int m = csr.num_rows;
  const int n = x_length;
  const int k = csr.num_cols;
  const int nnz = csr.indices->shape[0];
  const DType alpha = 1.0;
  const DType beta = 0.0;
  // device
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(hipsparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(hipsparseSetStream(thr_entry->cusparse_handle, stream));
  // all one data array
  DType* valptr = nullptr;
  if (!A_data) {
    valptr =
        static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType)));
    _Fill(valptr, nnz, static_cast<DType>(1.));
  }
#if DTKRT_VERSION >= 11000
  hipsparseSpMatDescr_t matA;
  hipsparseDnMatDescr_t matB, matC;
  constexpr auto dtype = cuda_dtype<DType>::value;
  constexpr auto idtype = cusparse_idtype<IdType>::value;
  CUSPARSE_CALL(hipsparseCreateCsr(
      &matA, m, k, nnz, static_cast<IdType*>(csr.indptr->data),
      static_cast<IdType*>(csr.indices->data),
      const_cast<DType*>(valptr ? valptr : A_data), idtype, idtype,
      HIPSPARSE_INDEX_BASE_ZERO, dtype));
  CUSPARSE_CALL(hipsparseCreateDnMat(
      &matB, k, n, n, const_cast<DType*>(B_data), dtype, HIPSPARSE_ORDER_ROW));
  CUSPARSE_CALL(
      hipsparseCreateDnMat(&matC, m, n, n, C_data, dtype, HIPSPARSE_ORDER_ROW));

  auto transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  auto transB = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  size_t workspace_size;
  hipsparseSpMMAlg_t spmm_alg = use_deterministic_alg_only
                                   ? HIPSPARSE_SPMM_CSR_ALG3
                                   : HIPSPARSE_SPMM_CSR_ALG2;
  CUSPARSE_CALL(hipsparseSpMM_bufferSize(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, spmm_alg, &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  CUSPARSE_CALL(hipsparseSpMM(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, spmm_alg, workspace));
  device->FreeWorkspace(ctx, workspace);

  CUSPARSE_CALL(hipsparseDestroySpMat(matA));
  CUSPARSE_CALL(hipsparseDestroyDnMat(matB));
  CUSPARSE_CALL(hipsparseDestroyDnMat(matC));
#else
  // allocate matrix for temporary transposed output
  DType* trans_out =
      static_cast<DType*>(device->AllocWorkspace(ctx, m * n * sizeof(DType)));

  hipsparseMatDescr_t descr;
  CUSPARSE_CALL(hipsparseCreateMatDescr(&descr));
  CUSPARSE_CALL(hipsparseSetMatType(descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
  CUSPARSE_CALL(hipsparseSetMatIndexBase(descr, HIPSPARSE_INDEX_BASE_ZERO));
  CUSPARSE_CALL(Xcsrmm2<DType>(
      thr_entry->cusparse_handle, HIPSPARSE_OPERATION_NON_TRANSPOSE,
      HIPSPARSE_OPERATION_TRANSPOSE, m, n, k, nnz, &alpha, descr,
      (valptr) ? valptr : A_data, static_cast<int32_t*>(csr.indptr->data),
      static_cast<int32_t*>(csr.indices->data), B_data, n, &beta, trans_out,
      m));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(descr));
  // transpose the output matrix
  _Transpose(trans_out, C_data, n, m);
  device->FreeWorkspace(ctx, trans_out);
#endif
  if (valptr) device->FreeWorkspace(ctx, valptr);
}
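
// Illustrative call sketch (added comment, not part of the original code):
// multiplying an int32-indexed CSR graph by a row-major dense feature matrix
// of width x_length. Passing A_data == nullptr makes the routine substitute an
// all-ones value array (i.e. an unweighted graph); `feat` and `out` are
// assumed to be device buffers of shape (num_cols x x_length) and
// (num_rows x x_length) respectively.
//
//   CusparseCsrmm2<float, int32_t>(ctx, csr, feat, /*A_data=*/nullptr, out,
//                                  x_length);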

/** Cusparse implementation of SpMM on Csr format. */
template <typename DType, typename IdType>
void CusparseCsrmm2Hetero(
    const DGLContext& ctx, const CSRMatrix& csr, const DType* B_data,
    const DType* A_data, DType* C_data, int64_t x_length, hipStream_t strm_id,
    bool use_deterministic_alg_only = false) {
  // We use csrmm2 to perform the following operation:
  // C = A x B, where A is a sparse matrix in csr format, B is the dense matrix
  // for node feature tensor. However, since cusparse only supports
  // column-major, while our tensor is stored in row-major, the actual
  // computation is: C = trans(A x trans(B)). Currently, we use cublasXgeam to
  // implement transposition and allocate intermediate workspace memory for
  // this.
  int int_maxlimit = std::numeric_limits<int>::max();
  CHECK_GE(int_maxlimit, (csr.num_rows));
  CHECK_GE(int_maxlimit, csr.num_cols);
  CHECK_GE(int_maxlimit, csr.indices->shape[0]);
  const int m = csr.num_rows;
  const int n = x_length;
  const int k = csr.num_cols;
  const int nnz = csr.indices->shape[0];
  const DType alpha = 1.0;
  const DType beta = 1.0;
  // device
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(hipsparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(hipsparseSetStream(thr_entry->cusparse_handle, strm_id));
  // all one data array
  DType* valptr = nullptr;
  if (!A_data) {
    valptr =
        static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType)));
    _Fill(valptr, nnz, static_cast<DType>(1.));
  }
#if DTKRT_VERSION >= 11000
  hipsparseSpMatDescr_t matA;
  hipsparseDnMatDescr_t matB, matC;
  constexpr auto dtype = cuda_dtype<DType>::value;
  constexpr auto idtype = cusparse_idtype<IdType>::value;
  CUSPARSE_CALL(hipsparseCreateCsr(
      &matA, m, k, nnz, static_cast<IdType*>(csr.indptr->data),
      static_cast<IdType*>(csr.indices->data),
      const_cast<DType*>(valptr ? valptr : A_data), idtype, idtype,
      HIPSPARSE_INDEX_BASE_ZERO, dtype));
  CUSPARSE_CALL(hipsparseCreateDnMat(
      &matB, k, n, n, const_cast<DType*>(B_data), dtype, HIPSPARSE_ORDER_ROW));
  CUSPARSE_CALL(
      hipsparseCreateDnMat(&matC, m, n, n, C_data, dtype, HIPSPARSE_ORDER_ROW));

  auto transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  auto transB = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  size_t workspace_size;
  hipsparseSpMMAlg_t spmm_alg = use_deterministic_alg_only
                                   ? HIPSPARSE_SPMM_CSR_ALG3
                                   : HIPSPARSE_SPMM_CSR_ALG2;
  CUSPARSE_CALL(hipsparseSpMM_bufferSize(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, spmm_alg, &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  CUSPARSE_CALL(hipsparseSpMM(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, spmm_alg, workspace));
  device->FreeWorkspace(ctx, workspace);

  CUSPARSE_CALL(hipsparseDestroySpMat(matA));
  CUSPARSE_CALL(hipsparseDestroyDnMat(matB));
  CUSPARSE_CALL(hipsparseDestroyDnMat(matC));
#else
  hipsparseMatDescr_t descr;
  CUSPARSE_CALL(hipsparseCreateMatDescr(&descr));
  CUSPARSE_CALL(hipsparseSetMatType(descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
  CUSPARSE_CALL(hipsparseSetMatIndexBase(descr, HIPSPARSE_INDEX_BASE_ZERO));
  CHECK_EQ(sizeof(IdType), sizeof(int32_t));
  CUSPARSE_CALL(Xcsrmm2<DType>(
      thr_entry->cusparse_handle, HIPSPARSE_OPERATION_NON_TRANSPOSE,
      HIPSPARSE_OPERATION_TRANSPOSE, m, n, k, nnz, &alpha, descr,
      (valptr) ? valptr : A_data, static_cast<int32_t*>(csr.indptr->data),
      static_cast<int32_t*>(csr.indices->data), B_data, n, &beta, C_data, m));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(descr));
#endif
  if (valptr) device->FreeWorkspace(ctx, valptr);
}

}  // namespace

#define SWITCH_OP(op, Op, ...)                                  \
  do {                                                          \
    if ((op) == "add") {                                        \
      typedef cuda::binary::Add<DType> Op;                      \
      { __VA_ARGS__ }                                           \
    } else if ((op) == "sub") {                                 \
      typedef cuda::binary::Sub<DType> Op;                      \
      { __VA_ARGS__ }                                           \
    } else if ((op) == "mul") {                                 \
      typedef cuda::binary::Mul<DType> Op;                      \
      { __VA_ARGS__ }                                           \
    } else if ((op) == "div") {                                 \
      typedef cuda::binary::Div<DType> Op;                      \
      { __VA_ARGS__ }                                           \
    } else if ((op) == "copy_lhs") {                            \
      typedef cuda::binary::CopyLhs<DType> Op;                  \
      { __VA_ARGS__ }                                           \
    } else if ((op) == "copy_rhs") {                            \
      typedef cuda::binary::CopyRhs<DType> Op;                  \
      { __VA_ARGS__ }                                           \
    } else {                                                    \
      LOG(FATAL) << "Unsupported SpMM binary operator: " << op; \
    }                                                           \
  } while (0)
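
// Illustrative usage of SWITCH_OP (added comment, not part of the original
// code): the macro binds the runtime string `op` to a concrete binary functor
// type named Op, which can then be used as a template argument inside the
// braces. The reducer functor name below is an assumption about the companion
// reduce functors and may differ in the actual callers.
//
//   SWITCH_OP(op, Op, {
//     cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
//         bcast, csr, ufeat, efeat, out, argu, arge);
//   });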

namespace cuda {

/**
 * @brief CUDA kernel of g-SpMM on Coo format.
 * @note It uses an edge-parallel strategy: different threadblocks (on the
 *       y-axis) are responsible for the computation on different edges, while
 *       threadblocks on the x-axis are responsible for different positions in
 *       the feature dimension. To avoid possible data hazards, it uses atomic
 *       operators for reduction.
 */
template <
    typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
    bool UseBcast = false, bool UseIdx = false>
__global__ void SpMMCooKernel(
    const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
    DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
    const Idx* __restrict__ row, const Idx* __restrict__ col,
    const Idx* __restrict__ edge_map, int64_t N, int64_t M, int64_t E,
    const int64_t* __restrict__ ubcast_off,
    const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
    int64_t efeat_len, int64_t out_len) {
  // SPMM with COO.
  Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
  const Idx stride_y = blockDim.y * gridDim.y;
  while (ty < E) {
    const Idx src = _ldg(row + ty);
    const Idx dst = _ldg(col + ty);
    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
    int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
    const int64_t stride_x = blockDim.x * gridDim.x;
    const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len) : nullptr;
    const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
    DType* outoff = out + dst * out_len;
    while (tx < out_len) {
      const int64_t lhs_add = UseBcast ? ubcast_off[tx] : tx;
      const int64_t rhs_add = UseBcast ? ebcast_off[tx] : tx;
      DType val = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
      Idx* arguoff = nullptr;  // arguoff is not used in SpMMCoo.
      Idx* argeoff = nullptr;  // argeoff is not used in SpMMCoo.
      ReduceOp::Call(outoff + tx, arguoff, argeoff, val, src, eid);
      tx += stride_x;
    }
    ty += stride_y;
  }
}

/**
 * @brief CUDA kernel to compute argu and arge in g-SpMM on Coo format.
 * @note It uses an edge-parallel strategy: different threadblocks (on the
 *       y-axis) are responsible for the computation on different edges, while
 *       threadblocks on the x-axis are responsible for different positions in
 *       the feature dimension.
 */
template <
    typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
    bool UseBcast = false, bool UseIdx = false>
__global__ void ArgSpMMCooKernel(
    const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
    DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
    const Idx* __restrict__ row, const Idx* __restrict__ col,
    const Idx* __restrict__ edge_map, int64_t N, int64_t M, int64_t E,
    const int64_t* __restrict__ ubcast_off,
    const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
    int64_t efeat_len, int64_t out_len) {
  // SPMM with COO arg max/min.
  Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
  const Idx stride_y = blockDim.y * gridDim.y;
  while (ty < E) {
    const Idx src = _ldg(row + ty);
    const Idx dst = _ldg(col + ty);
    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
    int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
    const int64_t stride_x = blockDim.x * gridDim.x;
    const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len) : nullptr;
    const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
    const DType* outoff = out + dst * out_len;
    Idx* arguoff = BinaryOp::use_lhs ? (arg_u + dst * out_len) : nullptr;
    Idx* argeoff = BinaryOp::use_rhs ? (arg_e + dst * out_len) : nullptr;
    while (tx < out_len) {
      int64_t lhs_add = UseBcast ? ubcast_off[tx] : tx;
      int64_t rhs_add = UseBcast ? ebcast_off[tx] : tx;
      DType val = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
      ReduceOp::CallArg(tx, arguoff, argeoff, val, outoff[tx], src, eid);
      tx += stride_x;
    }
    ty += stride_y;
  }
}

/**
 * @brief CUDA kernel of g-SpMM on Csr format.
 * @note It uses a node-parallel strategy: different threadblocks (on the
 *       x-axis) are responsible for the computation on different destination
 *       nodes, while threadblocks on the y-axis are responsible for different
 *       positions in the feature dimension.
 */
template <
    typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
    bool UseBcast = false, bool UseIdx = false>
__global__ void SpMMCsrKernel(
    const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
    DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
    const Idx* __restrict__ indptr, const Idx* __restrict__ indices,
    const Idx* __restrict__ edge_map, int64_t num_rows, int64_t num_cols,
    const int64_t* __restrict__ ubcast_off,
    const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
    int64_t efeat_len, int64_t out_len) {
  // SPMM with CSR.
  int ty = blockIdx.x * blockDim.y + threadIdx.y;
  const Idx stride_y = blockDim.y * gridDim.x;
  const int stride_x = blockDim.x * gridDim.y;
  while (ty < num_rows) {
    int tx = blockIdx.y * blockDim.x + threadIdx.x;
    while (tx < out_len) {
      typename accum_dtype<DType>::type local_accum = ReduceOp::zero();
      Idx local_argu = 0, local_arge = 0;
      const int lhs_add = UseBcast ? ubcast_off[tx] : tx;
      const int rhs_add = UseBcast ? ebcast_off[tx] : tx;
      for (Idx i = indptr[ty]; i < indptr[ty + 1]; ++i) {
        const Idx eid = UseIdx ? _ldg(edge_map + i) : i;
        const Idx cid = _ldg(indices + i);
        const DType* uoff =
            BinaryOp::use_lhs ? (ufeat + cid * ufeat_len) : nullptr;
        const DType* eoff =
            BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
        DType tmp_out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
        ReduceOp::Call(
            &local_accum, &local_argu, &local_arge, tmp_out, cid, eid);
      }
      // The use of += is to compute cross-type reducing on heterogeneous graph
      // when reduce op is `sum`.
      //     C = SpMM(SpA, B) + C
      // Separate kernel `SpMMCmpCsrHeteroKernel` is used for max- and
      // min-reducer. It does not affect the output on homogeneous graph as
      // `out` is initialized to zero.
      out[ty * out_len + tx] += static_cast<DType>(local_accum);
      if (ReduceOp::require_arg && BinaryOp::use_lhs)
        arg_u[ty * out_len + tx] = local_argu;
      if (ReduceOp::require_arg && BinaryOp::use_rhs)
        arg_e[ty * out_len + tx] = local_arge;
      tx += stride_x;
    }
    ty += stride_y;
  }
}

/**
 * @brief CUDA kernel of SpMM-Min/Max on Csr format.
 * @note It uses a node-parallel strategy: different threadblocks (on the
 *       y-axis) are responsible for the computation on different destination
 *       nodes, while threadblocks on the x-axis are responsible for different
 *       positions in the feature dimension.
 */
template <
    typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
    bool UseBcast = false, bool UseIdx = false>
__global__ void SpMMCmpCsrHeteroKernel(
    const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
    DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
    Idx* __restrict__ arg_u_ntype, Idx* __restrict__ arg_e_etype,
    const Idx* __restrict__ indptr, const Idx* __restrict__ indices,
    const Idx* __restrict__ edge_map, int64_t num_rows, int64_t num_cols,
    const int64_t* __restrict__ ubcast_off,
    const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
    int64_t efeat_len, int64_t out_len, const int src_type, const int etype) {
  // SPMM with CSR.
  int ty = blockIdx.y * blockDim.y + threadIdx.y;
  const Idx stride_y = blockDim.y * gridDim.y;
  const int stride_x = blockDim.x * gridDim.x;
  while (ty < num_rows) {
    int tx = blockIdx.x * blockDim.x + threadIdx.x;
    while (tx < out_len) {
      using accum_type = typename accum_dtype<DType>::type;
      accum_type local_accum =
          static_cast<accum_type>(out[ty * out_len + tx]);  // ReduceOp::zero();
      Idx local_argu = 0, local_arge = 0;
      const int lhs_add = UseBcast ? ubcast_off[tx] : tx;
      const int rhs_add = UseBcast ? ebcast_off[tx] : tx;
      for (Idx i = indptr[ty]; i < indptr[ty + 1]; ++i) {
        const Idx eid = UseIdx ? _ldg(edge_map + i) : i;
        const Idx cid = _ldg(indices + i);
        const DType* uoff =
            BinaryOp::use_lhs ? (ufeat + cid * ufeat_len) : nullptr;
        const DType* eoff =
            BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
        DType tmp_out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
        ReduceOp::Call(
            &local_accum, &local_argu, &local_arge, tmp_out, cid, eid);
      }
      // Update output only when the max/min value differs from the original
      // output.
      DType new_out = static_cast<DType>(local_accum);
      if (out[ty * out_len + tx] != new_out) {
        out[ty * out_len + tx] = new_out;
        if (ReduceOp::require_arg && BinaryOp::use_lhs) {
          arg_u[ty * out_len + tx] = local_argu;
          arg_u_ntype[ty * out_len + tx] = src_type;
        }
        if (ReduceOp::require_arg && BinaryOp::use_rhs) {
          arg_e[ty * out_len + tx] = local_arge;
          arg_e_etype[ty * out_len + tx] = etype;
        }
      }
      tx += stride_x;
    }
    ty += stride_y;
  }
}

/**
 * @brief CUDA implementation of g-SpMM on Coo format.
 * @param bcast Broadcast information.
 * @param coo The Coo matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes, which refers to the source node
 *        indices corresponding to the minimum/maximum values of the reduction
 *        result on destination nodes. It's useful in computing gradients of
 *        Min/Max reducer.
 * @param arge Arg-Min/Max on edges, which refers to the edge indices
 *        corresponding to the minimum/maximum values of the reduction result
 *        on destination nodes. It's useful in computing gradients of Min/Max
 *        reducer.
 */
template <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>
void SpMMCoo(
    const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge) {
  /**
   * TODO(Xin): Disable half precision for SpMMCoo due to the round-off error.
   * We should use fp32 for the accumulation but it's hard to modify the
   * current implementation.
   */
#if BF16_ENABLED
  if (std::is_same<DType, __half>::value ||
      std::is_same<DType, __hip_bfloat16>::value)
#else
  if (std::is_same<DType, __half>::value)
#endif  // BF16_ENABLED
    LOG(FATAL) << "SpMMCoo doesn't support half precision for now. "
               << "Please use SpMMCsr instead by allowing the graph "
               << "to materialize CSR/CSC formats.";
  const Idx *row = coo.row.Ptr<Idx>(), *col = coo.col.Ptr<Idx>(),
            *edge_map = coo.data.Ptr<Idx>();
  const DType *ufeat_data = ufeat.Ptr<DType>(),
              *efeat_data = efeat.Ptr<DType>();
  DType* out_data = out.Ptr<DType>();
  Idx *argu_data = argu.Ptr<Idx>(), *arge_data = arge.Ptr<Idx>();
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  const int64_t N = coo.num_rows, M = coo.num_cols, E = coo.row->shape[0];

  int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
  int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;

  int64_t out_size = out.NumElements();
  const int nt = FindNumThreads(out_size);
  const int nb = (out_size + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      _FillKernel, nb, nt, 0, stream, out_data, out_size, ReduceOp::zero());

  const int ntx = FindNumThreads(len);
  const int nty = CUDA_MAX_NUM_THREADS / ntx;
  const int nbx = (len + ntx - 1) / ntx;
  const int nby = FindNumBlocks<'y'>((E + nty - 1) / nty);
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const bool use_idx = !IsNullArray(coo.data);

  BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
    CUDA_KERNEL_CALL(
        (SpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
        nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
        arge_data, row, col, edge_map, N, M, E, ubcast_off, ebcast_off, lhs_len,
        rhs_len, len);
    if (ReduceOp::require_arg) {
      CUDA_KERNEL_CALL(
          (ArgSpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
          nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
          arge_data, row, col, edge_map, N, M, E, ubcast_off, ebcast_off,
          lhs_len, rhs_len, len);
    }
  });
}

/**
 * @brief CUDA implementation of g-SpMM on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes, which refers to the source node
 *        indices corresponding to the minimum/maximum values of the reduction
 *        result on destination nodes. It's useful in computing gradients of
 *        Min/Max reducer.
 * @param arge Arg-Min/Max on edges, which refers to the edge indices
 *        corresponding to the minimum/maximum values of the reduction result
 *        on destination nodes. It's useful in computing gradients of Min/Max
 *        reducer.
 */
template <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>
void SpMMCsr(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge) {
  const Idx* indptr = csr.indptr.Ptr<Idx>();
  const Idx* indices = csr.indices.Ptr<Idx>();
  const Idx* edge_map = csr.data.Ptr<Idx>();
  const DType* ufeat_data = ufeat.Ptr<DType>();
  const DType* efeat_data = efeat.Ptr<DType>();
  DType* out_data = out.Ptr<DType>();
  Idx* argu_data = argu.Ptr<Idx>();
  Idx* arge_data = arge.Ptr<Idx>();

  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();

  int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
  int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
  const int ntx = FindNumThreads(len);
  const int nty = CUDA_MAX_NUM_THREADS / ntx;
  const int nby = (len + ntx - 1) / ntx;
  const int nbx = FindNumBlocks<'x'>((csr.num_rows + nty - 1) / nty);
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const bool use_idx = !IsNullArray(csr.data);

  BCAST_IDX_CTX_SWITCH(
      bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off,
      {CUDA_KERNEL_CALL(
          (SpMMCsrKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
          nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
          arge_data, indptr, indices, edge_map, csr.num_rows, csr.num_cols,
          ubcast_off, ebcast_off, lhs_len, rhs_len, len)});
}
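
// Illustrative host-side sketch (added comment, not part of the original
// code): computing, for every destination node, the sum over incoming edges of
// ufeat[src] * efeat[eid]. The binary functor comes from this header's
// SWITCH_OP table; the reducer functor name and NullArray() placeholders are
// assumptions about the companion headers and callers, and argu/arge can be
// null because a sum reducer records no argmin/argmax.
//
//   SpMMCsr<int32_t, float, cuda::binary::Mul<float>,
//           cuda::reduce::Sum<int32_t, float> >(
//       bcast, csr, ufeat, efeat, out, NullArray(), NullArray());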

/**
 * @brief CUDA implementation of SpMM-Min/Max on Csr format on heterogeneous
 *        graph.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes, which refers to the source node
 *        indices corresponding to the minimum/maximum values of the reduction
 *        result on destination nodes. It's useful in computing gradients of
 *        Min/Max reducer.
 * @param arge Arg-Min/Max on edges, which refers to the edge indices
 *        corresponding to the minimum/maximum values of the reduction result
 *        on destination nodes. It's useful in computing gradients of Min/Max
 *        reducer.
 * @param argu_ntype Node type of the arg-Min/Max on source nodes, which refers
 *        to the source node types corresponding to the minimum/maximum values
 *        of the reduction result on destination nodes. It's useful in
 *        computing gradients of Min/Max reducer.
 * @param arge_etype Edge type of the arg-Min/Max on edges, which refers to the
 *        edge types corresponding to the minimum/maximum values of the
 *        reduction result on destination nodes. It's useful in computing
 *        gradients of Min/Max reducer.
 * @param src_type Node type of the source nodes of an etype.
 * @param etype Edge type.
 */
template <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>
void SpMMCmpCsrHetero(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge, NDArray argu_ntype,
    NDArray arge_etype, const int src_type, const int etype) {
  const Idx* indptr = csr.indptr.Ptr<Idx>();
  const Idx* indices = csr.indices.Ptr<Idx>();
  const Idx* edge_map = csr.data.Ptr<Idx>();
  const DType* ufeat_data = ufeat.Ptr<DType>();
  const DType* efeat_data = efeat.Ptr<DType>();
  DType* out_data = out.Ptr<DType>();
  Idx* argu_data = argu.Ptr<Idx>();
  Idx* arge_data = arge.Ptr<Idx>();

  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();

  int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
  int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
  const int ntx = FindNumThreads(len);
  const int nty = CUDA_MAX_NUM_THREADS / ntx;
  const int nbx = (len + ntx - 1) / ntx;
  const int nby = FindNumBlocks<'y'>((csr.num_rows + nty - 1) / nty);
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const bool use_idx = !IsNullArray(csr.data);

  BCAST_IDX_CTX_SWITCH(
      bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off,
      {CUDA_KERNEL_CALL(
          (SpMMCmpCsrHeteroKernel<
              Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
          nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
          arge_data, static_cast<Idx*>(argu_ntype->data),
          static_cast<Idx*>(arge_etype->data), indptr, indices, edge_map,
          csr.num_rows, csr.num_cols, ubcast_off, ebcast_off, lhs_len, rhs_len,
          len, src_type, etype)});
}

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_SPMM_CUH_