// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/spmm.cuh
 * @brief SPMM CUDA kernel function header.
 */
#ifndef DGL_ARRAY_CUDA_SPMM_CUH_
#define DGL_ARRAY_CUDA_SPMM_CUH_

#include <dgl/bcast.h>

#include <limits>

#include "../../runtime/cuda/cuda_common.h"
#include "utils.h"
#include "atomic.cuh"
#include "bf16.cuh"
#include "fp16.cuh"
#include "macro.cuh"

namespace dgl {

using namespace cuda;

namespace aten {

/**
 * @brief Determine whether cusparse SpMM function is applicable.
 */
template <typename DType, typename IdType>
inline bool cusparse_available(bool more_nnz_than_matrix_size) {
#if DTKRT_VERSION < 11000
  if (std::is_same<IdType, int>::value &&
      (std::is_same<DType, float>::value || std::is_same<DType, double>::value))
    return true;
  return false;
#else
  if (std::is_same<DType, __half>::value ||
      std::is_same<DType, __hip_bfloat16>::value)
    return false;  // cusparse's SpMM on fp16 is slow, temporarily disabled.
  // If the CSR matrix has more NNZ than matrix size, we should not use
  // cuSPARSE 11.1.
  return !more_nnz_than_matrix_size;
#endif
}

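// Illustrative only: a minimal dispatch sketch (not part of this header)
// showing how a caller might branch on cusparse_available(); `more_nnz`,
// `csr`, and the fallback path are assumptions about the caller in spmm.cu.
//
//   const bool more_nnz =
//       (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
//   if (cusparse_available<float, int32_t>(more_nnz)) {
//     // take the hipSPARSE path (CusparseCsrmm2 below)
//   } else {
//     // fall back to the custom SpMMCsr kernel below
//   }
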
namespace {

/** @brief Call cuBLAS geam API for transpose operation for float and double. */
template <typename DType>
hipblasStatus_t Xgeam(
    hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb,
    int m, int n, const DType* alpha, const DType* A, int lda,
    const DType* beta, const DType* B, int ldb, DType* C, int ldc) {
  LOG(FATAL) << "Not supported dtype";
  return HIPBLAS_STATUS_EXECUTION_FAILED;
}

template <>
hipblasStatus_t Xgeam<__half>(
    hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb,
    int m, int n, const __half* alpha, const __half* A, int lda,
    const __half* beta, const __half* B, int ldb, __half* C, int ldc) {
  // TODO(ndickson): There is no cublasHgeam, so a different
  // implementation would be required.
  LOG(FATAL) << "Xgeam does not support dtype half (FP16)";
  return HIPBLAS_STATUS_EXECUTION_FAILED;
}

#if BF16_ENABLED
template <>
hipblasStatus_t Xgeam<__hip_bfloat16>(
    hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb,
    int m, int n, const __hip_bfloat16* alpha, const __hip_bfloat16* A, int lda,
    const __hip_bfloat16* beta, const __hip_bfloat16* B, int ldb,
    __hip_bfloat16* C, int ldc) {
  // TODO(ndickson): There is no cublasHgeam, so a different
  // implementation would be required.
  LOG(FATAL) << "Xgeam does not support dtype bfloat16 (BF16)";
  return HIPBLAS_STATUS_EXECUTION_FAILED;
}
#endif  // BF16_ENABLED

template <>
hipblasStatus_t Xgeam<float>(
    hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb,
    int m, int n, const float* alpha, const float* A, int lda,
    const float* beta, const float* B, int ldb, float* C, int ldc) {
  return hipblasSgeam(
      handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
}

template <>
hipblasStatus_t Xgeam<double>(
    hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb,
    int m, int n, const double* alpha, const double* A, int lda,
    const double* beta, const double* B, int ldb, double* C, int ldc) {
  return hipblasDgeam(
      handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
}

/**
 * @brief Transpose operator kernel implementation.
 * @note Not efficient, but it is not a bottleneck; used for the float16 dtype.
 */
template <typename DType>
__global__ void _TransposeKernel(
    const DType* __restrict__ in, DType* __restrict__ out, int n, int m) {
  int i = blockIdx.x;
  for (int j = threadIdx.x; j < m; j += blockDim.x)
    out[i * m + j] = in[j * n + i];
}

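// Descriptive note (added for clarity): the callers in this file launch
// _TransposeKernel with one block per output row and threads striding over
// the output columns, i.e. for an input of shape (row, col):
//
//   int nt = FindNumThreads(row);  // threads cover the `row` output columns
//   int nb = col;                  // one block per output row
//   CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, stream, in, out, col, row);
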
/**
 * @brief Transpose the input matrix.
 * @param row number of rows of input matrix.
 * @param col number of columns of input matrix.
 */
template <typename DType>
void _Transpose(const DType* in, DType* out, int row, int col) {
  DType alpha = 1., beta = 0.;
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  if (!thr_entry->cublas_handle)
    CUBLAS_CALL(hipblasCreate(&(thr_entry->cublas_handle)));
  CUBLAS_CALL(hipblasSetStream(thr_entry->cublas_handle, stream));
  CUBLAS_CALL(Xgeam<DType>(
      thr_entry->cublas_handle, HIPBLAS_OP_T, HIPBLAS_OP_N, row, col, &alpha, in,
      col, &beta, nullptr, row, out, row));
}

/**
 * @brief Transpose the input matrix for data type half.
 * @note cuBLAS has no geam API for the half data type; fall back to our kernel.
 */
template <>
void _Transpose<__half>(const __half* in, __half* out, int row, int col) {
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  int nt = FindNumThreads(row);
  int nb = col;
  CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, stream, in, out, col, row);
}

#if BF16_ENABLED
/**
 * @brief Transpose the input matrix for data type bfloat16.
 * @note cuBLAS has no geam API for the bf16 data type; fall back to our kernel.
 */
template <>
void _Transpose<__hip_bfloat16>(
    const __hip_bfloat16* in, __hip_bfloat16* out, int row, int col) {
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  int nt = FindNumThreads(row);
  int nb = col;
  CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, stream, in, out, col, row);
}
#endif  // BF16_ENABLED

#if DTKRT_VERSION < 11000
template <typename DType>
hipsparseStatus_t Xcsrmm2(
    hipsparseHandle_t handle, hipsparseOperation_t transA,
    hipsparseOperation_t transB, int m, int n, int k, int nnz,
    const DType* alpha, const hipsparseMatDescr_t descrA, const DType* csrValA,
    const int* csrRowPtrA, const int* csrColIndA, const DType* B, int ldb,
    const DType* beta, DType* C, int ldc) {
  LOG(INFO) << "Not supported dtype";
  return HIPSPARSE_STATUS_EXECUTION_FAILED;
}

template <>
hipsparseStatus_t Xcsrmm2<float>(
    hipsparseHandle_t handle, hipsparseOperation_t transA,
    hipsparseOperation_t transB, int m, int n, int k, int nnz,
    const float* alpha, const hipsparseMatDescr_t descrA, const float* csrValA,
    const int* csrRowPtrA, const int* csrColIndA, const float* B, int ldb,
    const float* beta, float* C, int ldc) {
  return hipsparseScsrmm2(
      handle, transA, transB, m, n, k, nnz, alpha, descrA, csrValA, csrRowPtrA,
      csrColIndA, B, ldb, beta, C, ldc);
}

template <>
hipsparseStatus_t Xcsrmm2<double>(
    hipsparseHandle_t handle, hipsparseOperation_t transA,
    hipsparseOperation_t transB, int m, int n, int k, int nnz,
    const double* alpha, const hipsparseMatDescr_t descrA, const double* csrValA,
    const int* csrRowPtrA, const int* csrColIndA, const double* B, int ldb,
    const double* beta, double* C, int ldc) {
  return hipsparseDcsrmm2(
      handle, transA, transB, m, n, k, nnz, alpha, descrA, csrValA, csrRowPtrA,
      csrColIndA, B, ldb, beta, C, ldc);
}
#endif

/** Cusparse implementation of SpMM on Csr format. */
template <typename DType, typename IdType>
void CusparseCsrmm2(
    const DGLContext& ctx, const CSRMatrix& csr, const DType* B_data,
    const DType* A_data, DType* C_data, int x_length) {
  // We use csrmm2 to perform the following operation:
  // C = A x B, where A is a sparse matrix in csr format, B is the dense matrix
  // for node feature tensor. However, since cusparse only supports
  // column-major, while our tensor is stored in row-major, the actual
  // computation is: C = trans(A x trans(B)). Currently, we use cublasXgeam to
  // implement transposition and allocate intermediate workspace memory for
  // this.
  const int m = csr.num_rows;
  const int n = x_length;
  const int k = csr.num_cols;
  const int nnz = csr.indices->shape[0];
  const DType alpha = 1.0;
  const DType beta = 0.0;
  // device
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(hipsparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(hipsparseSetStream(thr_entry->cusparse_handle, stream));
  // all one data array
  DType* valptr = nullptr;
  if (!A_data) {
    valptr =
        static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType)));
    _Fill(valptr, nnz, static_cast<DType>(1.));
  }
#if DTKRT_VERSION >= 11000
  hipsparseSpMatDescr_t matA;
  hipsparseDnMatDescr_t matB, matC;
  constexpr auto dtype = cuda_dtype<DType>::value;
  constexpr auto idtype = cusparse_idtype<IdType>::value;
  CUSPARSE_CALL(hipsparseCreateCsr(
      &matA, m, k, nnz, static_cast<IdType*>(csr.indptr->data),
      static_cast<IdType*>(csr.indices->data),
      const_cast<DType*>(valptr ? valptr : A_data), idtype, idtype,
      HIPSPARSE_INDEX_BASE_ZERO, dtype));
  CUSPARSE_CALL(hipsparseCreateDnMat(
      &matB, k, n, n, const_cast<DType*>(B_data), dtype, HIPSPARSE_ORDER_ROW));
  CUSPARSE_CALL(
      hipsparseCreateDnMat(&matC, m, n, n, C_data, dtype, HIPSPARSE_ORDER_ROW));

  auto transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  auto transB = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  size_t workspace_size;
  CUSPARSE_CALL(hipsparseSpMM_bufferSize(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, HIPSPARSE_SPMM_CSR_ALG2, &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  CUSPARSE_CALL(hipsparseSpMM(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, HIPSPARSE_SPMM_CSR_ALG2, workspace));
  device->FreeWorkspace(ctx, workspace);

  CUSPARSE_CALL(hipsparseDestroySpMat(matA));
  CUSPARSE_CALL(hipsparseDestroyDnMat(matB));
  CUSPARSE_CALL(hipsparseDestroyDnMat(matC));
#else
  // allocate matrix for temporary transposed output
  DType* trans_out =
      static_cast<DType*>(device->AllocWorkspace(ctx, m * n * sizeof(DType)));

  hipsparseMatDescr_t descr;
  CUSPARSE_CALL(hipsparseCreateMatDescr(&descr));
  CUSPARSE_CALL(hipsparseSetMatType(descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
  CUSPARSE_CALL(hipsparseSetMatIndexBase(descr, HIPSPARSE_INDEX_BASE_ZERO));
  CUSPARSE_CALL(Xcsrmm2<DType>(
      thr_entry->cusparse_handle, HIPSPARSE_OPERATION_NON_TRANSPOSE,
      HIPSPARSE_OPERATION_TRANSPOSE, m, n, k, nnz, &alpha, descr,
      (valptr) ? valptr : A_data, static_cast<int32_t*>(csr.indptr->data),
      static_cast<int32_t*>(csr.indices->data), B_data, n, &beta, trans_out,
      m));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(descr));
  // transpose the output matrix
  _Transpose(trans_out, C_data, n, m);
  device->FreeWorkspace(ctx, trans_out);
#endif
  if (valptr) device->FreeWorkspace(ctx, valptr);
}

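// Illustrative only: a sketch of a typical call to CusparseCsrmm2, assuming an
// int32-indexed CSR matrix and contiguous row-major float features; passing
// nullptr for A_data makes the edge weights default to 1 (see `valptr` above).
//
//   CusparseCsrmm2<float, int32_t>(
//       ctx, csr, static_cast<float*>(ufeat_ptr), /*A_data=*/nullptr,
//       static_cast<float*>(out_ptr), feat_len);
//
// Here ufeat_ptr, out_ptr and feat_len are hypothetical caller-side names.
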
/** Cusparse implementation of SpMM on Csr format for heterogeneous graphs;
 *  accumulates the result into C_data (beta = 1). */
template <typename DType, typename IdType>
void CusparseCsrmm2Hetero(
    const DGLContext& ctx, const CSRMatrix& csr, const DType* B_data,
    const DType* A_data, DType* C_data, int64_t x_length,
    hipStream_t strm_id) {
  // We use csrmm2 to perform the following operation:
  // C = A x B, where A is a sparse matrix in csr format, B is the dense matrix
  // for node feature tensor. However, since cusparse only supports
  // column-major, while our tensor is stored in row-major, the actual
  // computation is: C = trans(A x trans(B)). Currently, we use cublasXgeam to
  // implement transposition and allocate intermediate workspace memory for
  // this.
  int int_maxlimit = std::numeric_limits<int>::max();
  CHECK_GE(int_maxlimit, (csr.num_rows));
  CHECK_GE(int_maxlimit, csr.num_cols);
  CHECK_GE(int_maxlimit, csr.indices->shape[0]);
  const int m = csr.num_rows;
  const int n = x_length;
  const int k = csr.num_cols;
  const int nnz = csr.indices->shape[0];
  const DType alpha = 1.0;
  const DType beta = 1.0;
  // device
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(hipsparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(hipsparseSetStream(thr_entry->cusparse_handle, strm_id));
  // all one data array
  DType* valptr = nullptr;
  if (!A_data) {
    valptr =
        static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType)));
    _Fill(valptr, nnz, static_cast<DType>(1.));
  }
#if DTKRT_VERSION >= 11000
  hipsparseSpMatDescr_t matA;
  hipsparseDnMatDescr_t matB, matC;
  constexpr auto dtype = cuda_dtype<DType>::value;
  constexpr auto idtype = cusparse_idtype<IdType>::value;
  CUSPARSE_CALL(hipsparseCreateCsr(
      &matA, m, k, nnz, static_cast<IdType*>(csr.indptr->data),
      static_cast<IdType*>(csr.indices->data),
      const_cast<DType*>(valptr ? valptr : A_data), idtype, idtype,
      HIPSPARSE_INDEX_BASE_ZERO, dtype));
  CUSPARSE_CALL(hipsparseCreateDnMat(
      &matB, k, n, n, const_cast<DType*>(B_data), dtype, HIPSPARSE_ORDER_ROW));
  CUSPARSE_CALL(
      hipsparseCreateDnMat(&matC, m, n, n, C_data, dtype, HIPSPARSE_ORDER_ROW));

  auto transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  auto transB = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  size_t workspace_size;
  CUSPARSE_CALL(hipsparseSpMM_bufferSize(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, HIPSPARSE_SPMM_CSR_ALG2, &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  CUSPARSE_CALL(hipsparseSpMM(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, HIPSPARSE_SPMM_CSR_ALG2, workspace));
  device->FreeWorkspace(ctx, workspace);

  CUSPARSE_CALL(hipsparseDestroySpMat(matA));
  CUSPARSE_CALL(hipsparseDestroyDnMat(matB));
  CUSPARSE_CALL(hipsparseDestroyDnMat(matC));
#else
  hipsparseMatDescr_t descr;
  CUSPARSE_CALL(hipsparseCreateMatDescr(&descr));
  CUSPARSE_CALL(hipsparseSetMatType(descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
  CUSPARSE_CALL(hipsparseSetMatIndexBase(descr, HIPSPARSE_INDEX_BASE_ZERO));
  CHECK_EQ(sizeof(IdType), sizeof(int32_t));
  CUSPARSE_CALL(Xcsrmm2<DType>(
      thr_entry->cusparse_handle, HIPSPARSE_OPERATION_NON_TRANSPOSE,
      HIPSPARSE_OPERATION_TRANSPOSE, m, n, k, nnz, &alpha, descr,
      (valptr) ? valptr : A_data, static_cast<int32_t*>(csr.indptr->data),
      static_cast<int32_t*>(csr.indices->data), B_data, n, &beta, C_data, m));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(descr));
#endif
  if (valptr) device->FreeWorkspace(ctx, valptr);
}

}  // namespace

#define SWITCH_OP(op, Op, ...)                                  \
  do {                                                          \
    if ((op) == "add") {                                        \
      typedef cuda::binary::Add<DType> Op;                      \
      { __VA_ARGS__ }                                           \
    } else if ((op) == "sub") {                                 \
      typedef cuda::binary::Sub<DType> Op;                      \
      { __VA_ARGS__ }                                           \
    } else if ((op) == "mul") {                                 \
      typedef cuda::binary::Mul<DType> Op;                      \
      { __VA_ARGS__ }                                           \
    } else if ((op) == "div") {                                 \
      typedef cuda::binary::Div<DType> Op;                      \
      { __VA_ARGS__ }                                           \
    } else if ((op) == "copy_lhs") {                            \
      typedef cuda::binary::CopyLhs<DType> Op;                  \
      { __VA_ARGS__ }                                           \
    } else if ((op) == "copy_rhs") {                            \
      typedef cuda::binary::CopyRhs<DType> Op;                  \
      { __VA_ARGS__ }                                           \
    } else {                                                    \
      LOG(FATAL) << "Unsupported SpMM binary operator: " << op; \
    }                                                           \
  } while (0)

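// Illustrative only: SWITCH_OP turns a runtime op string into the BinaryOp
// template argument. A sketch of how a dispatcher might use it with SpMMCsr
// (the reduce functor name is an assumption about the dispatcher, not a
// definition from this header):
//
//   SWITCH_OP(op, Op, {
//     SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
//         bcast, csr, ufeat, efeat, out, argu, arge);
//   });
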
namespace cuda {

/**
 * @brief CUDA kernel of g-SpMM on Coo format.
 * @note It uses an edge-parallel strategy: different threadblocks (on the
 *       y-axis) are responsible for the computation on different edges, while
 *       threadblocks on the x-axis are responsible for the computation on
 *       different positions in the feature dimension. To avoid possible data
 *       hazards, it uses atomic operators for reduction.
 */
template <
    typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
    bool UseBcast = false, bool UseIdx = false>
__global__ void SpMMCooKernel(
    const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
    DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
    const Idx* __restrict__ row, const Idx* __restrict__ col,
    const Idx* __restrict__ edge_map, int64_t N, int64_t M, int64_t E,
    const int64_t* __restrict__ ubcast_off,
    const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
    int64_t efeat_len, int64_t out_len) {
  // SPMM with COO.
  Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
  const Idx stride_y = blockDim.y * gridDim.y;
  while (ty < E) {
    const Idx src = _ldg(row + ty);
    const Idx dst = _ldg(col + ty);
    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
    int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
    const int64_t stride_x = blockDim.x * gridDim.x;
    const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len) : nullptr;
    const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
    DType* outoff = out + dst * out_len;
    while (tx < out_len) {
      const int64_t lhs_add = UseBcast ? ubcast_off[tx] : tx;
      const int64_t rhs_add = UseBcast ? ebcast_off[tx] : tx;
      DType val = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
      Idx* arguoff = nullptr;  // arguoff is not used in SpMMCoo.
      Idx* argeoff = nullptr;  // argeoff is not used in SpMMCoo.
      ReduceOp::Call(outoff + tx, arguoff, argeoff, val, src, eid);
      tx += stride_x;
    }
    ty += stride_y;
  }
}

/**
 * @brief CUDA kernel to compute argu and arge in g-SpMM on Coo format.
 * @note It uses an edge-parallel strategy: different threadblocks (on the
 *       y-axis) are responsible for the computation on different edges, while
 *       threadblocks on the x-axis are responsible for the computation on
 *       different positions in the feature dimension.
 */
template <
    typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
    bool UseBcast = false, bool UseIdx = false>
__global__ void ArgSpMMCooKernel(
    const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
    DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
    const Idx* __restrict__ row, const Idx* __restrict__ col,
    const Idx* __restrict__ edge_map, int64_t N, int64_t M, int64_t E,
    const int64_t* __restrict__ ubcast_off,
    const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
    int64_t efeat_len, int64_t out_len) {
  // SPMM with COO arg max/min.
  Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
  const Idx stride_y = blockDim.y * gridDim.y;
  while (ty < E) {
    const Idx src = _ldg(row + ty);
    const Idx dst = _ldg(col + ty);
    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
    int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
    const int64_t stride_x = blockDim.x * gridDim.x;
    const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len) : nullptr;
    const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
    const DType* outoff = out + dst * out_len;
    Idx* arguoff = BinaryOp::use_lhs ? (arg_u + dst * out_len) : nullptr;
    Idx* argeoff = BinaryOp::use_rhs ? (arg_e + dst * out_len) : nullptr;
    while (tx < out_len) {
      int64_t lhs_add = UseBcast ? ubcast_off[tx] : tx;
      int64_t rhs_add = UseBcast ? ebcast_off[tx] : tx;
      DType val = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
      ReduceOp::CallArg(tx, arguoff, argeoff, val, outoff[tx], src, eid);
      tx += stride_x;
    }
    ty += stride_y;
  }
}

/**
 * @brief CUDA kernel of g-SpMM on Csr format.
 * @note It uses a node-parallel strategy: blocks along one grid dimension are
 *       responsible for the computation on different destination nodes, while
 *       blocks along the other grid dimension are responsible for the
 *       computation on different positions in the feature dimension.
 */
template <
    typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
    bool UseBcast = false, bool UseIdx = false>
__global__ void SpMMCsrKernel(
    const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
    DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
    const Idx* __restrict__ indptr, const Idx* __restrict__ indices,
    const Idx* __restrict__ edge_map, int64_t num_rows, int64_t num_cols,
    const int64_t* __restrict__ ubcast_off,
    const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
    int64_t efeat_len, int64_t out_len) {
  // SPMM with CSR.
  int ty = blockIdx.x * blockDim.y + threadIdx.y;
  const Idx stride_y = blockDim.y * gridDim.x;
  const int stride_x = blockDim.x * gridDim.y;
  while (ty < num_rows) {
    int tx = blockIdx.y * blockDim.x + threadIdx.x;
    while (tx < out_len) {
      typename accum_dtype<DType>::type local_accum = ReduceOp::zero();
      Idx local_argu = 0, local_arge = 0;
      const int lhs_add = UseBcast ? ubcast_off[tx] : tx;
      const int rhs_add = UseBcast ? ebcast_off[tx] : tx;
      for (Idx i = indptr[ty]; i < indptr[ty + 1]; ++i) {
        const Idx eid = UseIdx ? _ldg(edge_map + i) : i;
        const Idx cid = _ldg(indices + i);
        const DType* uoff =
            BinaryOp::use_lhs ? (ufeat + cid * ufeat_len) : nullptr;
        const DType* eoff =
            BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
        DType out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
        ReduceOp::Call(&local_accum, &local_argu, &local_arge, out, cid, eid);
      }
      // The use of += is to compute cross-type reduction on heterogeneous
      // graphs when the reduce op is `sum`:
      //     C = SpMM(SpA, B) + C
      // A separate kernel, `SpMMCmpCsrHeteroKernel`, is used for the max- and
      // min-reducers. It does not affect the output on homogeneous graphs as
      // `out` is initialized to zero.
      out[ty * out_len + tx] += static_cast<DType>(local_accum);
      if (ReduceOp::require_arg && BinaryOp::use_lhs)
        arg_u[ty * out_len + tx] = local_argu;
      if (ReduceOp::require_arg && BinaryOp::use_rhs)
        arg_e[ty * out_len + tx] = local_arge;
      tx += stride_x;
    }
    ty += stride_y;
  }
}

/**
 * @brief CUDA kernel of SpMM-Min/Max on Csr format.
 * @note It uses a node-parallel strategy: different threadblocks (on the
 *       y-axis) are responsible for the computation on different destination
 *       nodes, while threadblocks on the x-axis are responsible for the
 *       computation on different positions in the feature dimension.
 */
template <
    typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
    bool UseBcast = false, bool UseIdx = false>
__global__ void SpMMCmpCsrHeteroKernel(
    const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
    DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
    Idx* __restrict__ arg_u_ntype, Idx* __restrict__ arg_e_etype,
    const Idx* __restrict__ indptr, const Idx* __restrict__ indices,
    const Idx* __restrict__ edge_map, int64_t num_rows, int64_t num_cols,
    const int64_t* __restrict__ ubcast_off,
    const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
    int64_t efeat_len, int64_t out_len, const int src_type, const int etype) {
  // SPMM with CSR.
  int ty = blockIdx.y * blockDim.y + threadIdx.y;
  const Idx stride_y = blockDim.y * gridDim.y;
  const int stride_x = blockDim.x * gridDim.x;
  while (ty < num_rows) {
    int tx = blockIdx.x * blockDim.x + threadIdx.x;
    while (tx < out_len) {
      using accum_type = typename accum_dtype<DType>::type;
      accum_type local_accum = static_cast<accum_type>(
          out[ty * out_len + tx]);  // ReduceOp::zero();
      Idx local_argu = 0, local_arge = 0;
      const int lhs_add = UseBcast ? ubcast_off[tx] : tx;
      const int rhs_add = UseBcast ? ebcast_off[tx] : tx;
      for (Idx i = indptr[ty]; i < indptr[ty + 1]; ++i) {
        const Idx eid = UseIdx ? _ldg(edge_map + i) : i;
        const Idx cid = _ldg(indices + i);
        const DType* uoff =
            BinaryOp::use_lhs ? (ufeat + cid * ufeat_len) : nullptr;
        const DType* eoff =
            BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
        DType tmp_out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
        ReduceOp::Call(
            &local_accum, &local_argu, &local_arge, tmp_out, cid, eid);
      }
      // Update the output only when the max/min value differs from the
      // original output.
      DType new_out = static_cast<DType>(local_accum);
      if (out[ty * out_len + tx] != new_out) {
        out[ty * out_len + tx] = new_out;
        if (ReduceOp::require_arg && BinaryOp::use_lhs) {
          arg_u[ty * out_len + tx] = local_argu;
          arg_u_ntype[ty * out_len + tx] = src_type;
        }
        if (ReduceOp::require_arg && BinaryOp::use_rhs) {
          arg_e[ty * out_len + tx] = local_arge;
          arg_e_etype[ty * out_len + tx] = etype;
        }
      }
      tx += stride_x;
    }
    ty += stride_y;
  }
}

/**
 * @brief CUDA implementation of g-SpMM on Coo format.
 * @param bcast Broadcast information.
 * @param coo The Coo matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes, which refers to the source node
 *        indices corresponding to the minimum/maximum values of the reduction
 *        result on destination nodes. It is useful in computing gradients of
 *        the Min/Max reducer.
 * @param arge Arg-Min/Max on edges, which refers to the edge indices
 *        corresponding to the minimum/maximum values of the reduction result
 *        on destination nodes. It is useful in computing gradients of the
 *        Min/Max reducer.
 */
template <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>
void SpMMCoo(
    const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge) {
  /**
   * TODO(Xin): Disable half precision for SpMMCoo due to the round-off error.
   * We should use fp32 for the accumulation but it's hard to modify the
   * current implementation.
   */
#if BF16_ENABLED
  if (std::is_same<DType, __half>::value ||
      std::is_same<DType, __hip_bfloat16>::value)
#else
  if (std::is_same<DType, __half>::value)
#endif  // BF16_ENABLED
    LOG(FATAL) << "SpMMCoo doesn't support half precision for now. "
               << "Please use SpMMCsr instead by allowing the graph "
               << "to materialize CSR/CSC formats.";
  const Idx *row = coo.row.Ptr<Idx>(), *col = coo.col.Ptr<Idx>(),
            *edge_map = coo.data.Ptr<Idx>();
  const DType *ufeat_data = ufeat.Ptr<DType>(),
              *efeat_data = efeat.Ptr<DType>();
  DType* out_data = out.Ptr<DType>();
  Idx *argu_data = argu.Ptr<Idx>(), *arge_data = arge.Ptr<Idx>();
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  const int64_t N = coo.num_rows, M = coo.num_cols, E = coo.row->shape[0];

  int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
  int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;

  int64_t out_size = out.NumElements();
  const int nt = FindNumThreads(out_size);
  const int nb = (out_size + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      _FillKernel, nb, nt, 0, stream, out_data, out_size, ReduceOp::zero());

  const int ntx = FindNumThreads(len);
  const int nty = CUDA_MAX_NUM_THREADS / ntx;
  const int nbx = (len + ntx - 1) / ntx;
  const int nby = FindNumBlocks<'y'>((E + nty - 1) / nty);
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const bool use_idx = !IsNullArray(coo.data);

  BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
    CUDA_KERNEL_CALL(
        (SpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
        nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
        arge_data, row, col, edge_map, N, M, E, ubcast_off, ebcast_off, lhs_len,
        rhs_len, len);
    if (ReduceOp::require_arg) {
      CUDA_KERNEL_CALL(
          (ArgSpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
          nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
          arge_data, row, col, edge_map, N, M, E, ubcast_off, ebcast_off,
          lhs_len, rhs_len, len);
    }
  });
}

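// Illustrative only: a sketch of how SpMMCoo might be instantiated by a
// dispatcher (the reduce functor name/signature and NullArray() usage are
// assumptions about the caller, not definitions from this header):
//
//   SpMMCoo<int32_t, float, cuda::binary::Add<float>,
//           cuda::reduce::Sum<int32_t, float, true> >(
//       bcast, coo, ufeat, efeat, out, NullArray(), NullArray());
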
/**
 * @brief CUDA implementation of g-SpMM on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes, which refers to the source node
 *        indices corresponding to the minimum/maximum values of the reduction
 *        result on destination nodes. It is useful in computing gradients of
 *        the Min/Max reducer.
 * @param arge Arg-Min/Max on edges, which refers to the edge indices
 *        corresponding to the minimum/maximum values of the reduction result
 *        on destination nodes. It is useful in computing gradients of the
 *        Min/Max reducer.
 */
template <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>
void SpMMCsr(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge) {
  const Idx* indptr = csr.indptr.Ptr<Idx>();
  const Idx* indices = csr.indices.Ptr<Idx>();
  const Idx* edge_map = csr.data.Ptr<Idx>();
  const DType* ufeat_data = ufeat.Ptr<DType>();
  const DType* efeat_data = efeat.Ptr<DType>();
  DType* out_data = out.Ptr<DType>();
  Idx* argu_data = argu.Ptr<Idx>();
  Idx* arge_data = arge.Ptr<Idx>();

  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();

  int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
  int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
  const int ntx = FindNumThreads(len);
  const int nty = CUDA_MAX_NUM_THREADS / ntx;
  const int nby = (len + ntx - 1) / ntx;
  const int nbx = FindNumBlocks<'x'>((csr.num_rows + nty - 1) / nty);
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const bool use_idx = !IsNullArray(csr.data);

  BCAST_IDX_CTX_SWITCH(
      bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off,
      {CUDA_KERNEL_CALL(
          (SpMMCsrKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
          nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
          arge_data, indptr, indices, edge_map, csr.num_rows, csr.num_cols,
          ubcast_off, ebcast_off, lhs_len, rhs_len, len)});
}

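// Descriptive note on the launch configuration in SpMMCsr above: feature
// positions are tiled along grid dimension y (nby blocks of ntx threads) and
// destination rows along grid dimension x (nbx blocks of nty threads), which
// matches the indexing inside SpMMCsrKernel (ty is derived from blockIdx.x,
// tx from blockIdx.y).
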
/**
 * @brief CUDA implementation of SpMM-Min/Max on Csr format on heterogeneous
 *        graphs.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param ufeat The feature on source nodes.
 * @param efeat The feature on edges.
 * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes, which refers to the source node
 *        indices corresponding to the minimum/maximum values of the reduction
 *        result on destination nodes. It is useful in computing gradients of
 *        the Min/Max reducer.
 * @param arge Arg-Min/Max on edges, which refers to the edge indices
 *        corresponding to the minimum/maximum values of the reduction result
 *        on destination nodes. It is useful in computing gradients of the
 *        Min/Max reducer.
 * @param argu_ntype Node type of the arg-Min/Max on source nodes, which refers
 *        to the source node types corresponding to the minimum/maximum values
 *        of the reduction result on destination nodes. It is useful in
 *        computing gradients of the Min/Max reducer.
 * @param arge_etype Edge type of the arg-Min/Max on edges, which refers to the
 *        edge types corresponding to the minimum/maximum values of the
 *        reduction result on destination nodes. It is useful in computing
 *        gradients of the Min/Max reducer.
 * @param src_type Node type of the source nodes of an etype.
 * @param etype Edge type.
 */
template <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>
void SpMMCmpCsrHetero(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge, NDArray argu_ntype,
    NDArray arge_etype, const int src_type, const int etype) {
  const Idx* indptr = csr.indptr.Ptr<Idx>();
  const Idx* indices = csr.indices.Ptr<Idx>();
  const Idx* edge_map = csr.data.Ptr<Idx>();
  const DType* ufeat_data = ufeat.Ptr<DType>();
  const DType* efeat_data = efeat.Ptr<DType>();
  DType* out_data = out.Ptr<DType>();
  Idx* argu_data = argu.Ptr<Idx>();
  Idx* arge_data = arge.Ptr<Idx>();

  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();

  int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
  int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
  const int ntx = FindNumThreads(len);
  const int nty = CUDA_MAX_NUM_THREADS / ntx;
  const int nbx = (len + ntx - 1) / ntx;
  const int nby = FindNumBlocks<'y'>((csr.num_rows + nty - 1) / nty);
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const bool use_idx = !IsNullArray(csr.data);

  BCAST_IDX_CTX_SWITCH(
      bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off,
      {CUDA_KERNEL_CALL(
          (SpMMCmpCsrHeteroKernel<
              Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
          nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
          arge_data, static_cast<Idx*>(argu_ntype->data),
          static_cast<Idx*>(arge_etype->data), indptr, indices, edge_map,
          csr.num_rows, csr.num_cols, ubcast_off, ebcast_off, lhs_len, rhs_len,
          len, src_type, etype)});
}

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_SPMM_CUH_