// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/csr_mm.cu
 * @brief SpSpMM/SpGEMM C APIs and definitions.
 */
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>

#include <limits>

#include "../../runtime/cuda/cuda_common.h"
#include "cusparse_dispatcher.cuh"
#include "functor.cuh"

namespace dgl {

using namespace dgl::runtime;

namespace aten {
namespace cusparse {
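// Two implementations follow: for DTKRT >= 12.0 this file uses the generic
// hipsparseSpGEMM_* API (work estimation -> compute -> copy); for older
// toolkits it falls back to the legacy csrgemm2 interface.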

#if DTKRT_VERSION >= 12000

/** @brief Cusparse implementation of SpGEMM on Csr format for DTKRT 12.0+ */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
    const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
    const NDArray B_weights_array) {
  // We use Spgemm (SpSpMM) to perform following operation:
  // C = A x B, where A, B and C are sparse matrices in csr format.
  const int nnzA = A.indices->shape[0];
  const int nnzB = B.indices->shape[0];
  const DType alpha = 1.0;
  const DType beta = 0.0;
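  // Per the hipSPARSE/cuSPARSE SpGEMM convention, the routine evaluates
  // C = alpha * op(A) * op(B) + beta * C, so alpha = 1 and beta = 0 yield
  // the plain product C = A x B.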
  auto transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  auto transB = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  // device
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  const DType* A_weights = A_weights_array.Ptr<DType>();
  const DType* B_weights = B_weights_array.Ptr<DType>();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(hipsparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(hipsparseSetStream(thr_entry->cusparse_handle, stream));
  // Sparse matrix descriptors for the operands A, B and the output C.
  hipsparseSpMatDescr_t matA, matB, matC;
  IdArray dC_csrOffsets =
      IdArray::Empty({A.num_rows + 1}, A.indptr->dtype, A.indptr->ctx);
  IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>();
  constexpr auto idtype = cusparse_idtype<IdType>::value;
  constexpr auto dtype = cuda_dtype<DType>::value;
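  // These traits (from cuda_common.h) map the template parameters to the
  // corresponding hipSPARSE index-type and data-type enum values.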
  // Create sparse matrix A, B and C in CSR format
  CUSPARSE_CALL(hipsparseCreateCsr(
      &matA, A.num_rows, A.num_cols, nnzA, A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(),
      // hipsparseCreateCsr only accepts non-const pointers.
      const_cast<DType*>(A_weights), idtype, idtype, HIPSPARSE_INDEX_BASE_ZERO,
      dtype));
  CUSPARSE_CALL(hipsparseCreateCsr(
      &matB, B.num_rows, B.num_cols, nnzB, B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(),
      // hipsparseCreateCsr only accepts non-const pointers.
      const_cast<DType*>(B_weights), idtype, idtype, HIPSPARSE_INDEX_BASE_ZERO,
      dtype));
  CUSPARSE_CALL(hipsparseCreateCsr(
      &matC, A.num_rows, B.num_cols, 0, dC_csrOffsets_data, nullptr, nullptr,
      idtype, idtype, HIPSPARSE_INDEX_BASE_ZERO, dtype));
  // SpGEMM Computation
  hipsparseSpGEMMDescr_t spgemmDesc;
  hipsparseSpGEMMAlg_t alg = HIPSPARSE_SPGEMM_DEFAULT;

  CUSPARSE_CALL(hipsparseSpGEMM_createDescr(&spgemmDesc));
  size_t workspace_size1 = 0, workspace_size2 = 0, workspace_size3 = 0;
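  // The generic SpGEMM API is a multi-phase protocol: each phase below is
  // invoked twice, first with a NULL buffer to query the required workspace
  // size and then again with the allocated buffer.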
  // ask for workspace_size1 bytes of external memory
  CUSPARSE_CALL(hipsparseSpGEMM_workEstimation(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, alg, spgemmDesc, &workspace_size1, NULL));
  void* workspace1 = device->AllocWorkspace(ctx, workspace_size1);
  // inspect the matrices A and B to understand the memory requirement
  hipsparseStatus_t e = hipsparseSpGEMM_workEstimation(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, alg, spgemmDesc, &workspace_size1, workspace1);
  // HIPSPARSE_SPGEMM_DEFAULT does not support num_prods > 2^31 - 1
  // and throws an insufficient-memory error inside the workEstimation call
  if (e == HIPSPARSE_STATUS_INSUFFICIENT_RESOURCES) {
    // fall back to ALG2 to estimate num_prods
    alg = HIPSPARSE_SPGEMM_ALG2;
    device->FreeWorkspace(ctx, workspace1);
    // rerun hipsparseSpGEMM_workEstimation
    CUSPARSE_CALL(hipsparseSpGEMM_workEstimation(
        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
        matC, dtype, alg, spgemmDesc, &workspace_size1, NULL));
    workspace1 = device->AllocWorkspace(ctx, workspace_size1);
    CUSPARSE_CALL(hipsparseSpGEMM_workEstimation(
        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
        matC, dtype, alg, spgemmDesc, &workspace_size1, workspace1));
  } else {
    CHECK(e == HIPSPARSE_STATUS_SUCCESS) << "CUSPARSE ERROR in SpGEMM: " << e;
  }

  // get the number of intermediate products required for SpGEMM compute
  // num_prods indicates device memory consumption for SpGEMM if using ALG2/3
  int64_t num_prods;
  CUSPARSE_CALL(cusparseSpGEMM_getNumProducts(spgemmDesc, &num_prods));

  // assume at least ~15 GB of free GPU memory for the heuristics below to work
  // user-defined medium problem size (below will use DEFAULT)
  int64_t MEDIUM_NUM_PRODUCTS = 400000000;  // 400*1000*1000;
  // user-defined large problem size (above will use ALG3)
  int64_t LARGE_NUM_PRODUCTS = 800000000;  // 800*1000*1000;
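  // E.g., num_prods = 3e8 stays on DEFAULT, 5e8 switches to ALG2 below, and
  // 9e8 switches to ALG3.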

  // switch to ALG2/ALG3 for medium & large problem size
  if (alg == HIPSPARSE_SPGEMM_DEFAULT && num_prods > MEDIUM_NUM_PRODUCTS) {
    // use ALG3 for very large problems
    alg = num_prods > LARGE_NUM_PRODUCTS ? HIPSPARSE_SPGEMM_ALG3
                                         : HIPSPARSE_SPGEMM_ALG2;

    device->FreeWorkspace(ctx, workspace1);
    // rerun hipsparseSpGEMM_workEstimation
    CUSPARSE_CALL(hipsparseSpGEMM_workEstimation(
        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
        matC, dtype, alg, spgemmDesc, &workspace_size1, NULL));
    workspace1 = device->AllocWorkspace(ctx, workspace_size1);
    CUSPARSE_CALL(hipsparseSpGEMM_workEstimation(
        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
        matC, dtype, alg, spgemmDesc, &workspace_size1, workspace1));
  } else if (alg == HIPSPARSE_SPGEMM_ALG2 && num_prods > LARGE_NUM_PRODUCTS) {
    // no need to rerun hipsparseSpGEMM_workEstimation between ALG2 and ALG3
    alg = HIPSPARSE_SPGEMM_ALG3;
  }

  if (alg == HIPSPARSE_SPGEMM_ALG2 || alg == HIPSPARSE_SPGEMM_ALG3) {
    // estimate memory for ALG2/ALG3; note chunk_fraction is only used by ALG3
    // reduce chunk_fraction if it crashes due to memory, at the cost of speed
    float chunk_fraction = num_prods < 4 * LARGE_NUM_PRODUCTS ? 0.15 : 0.05;
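    // ALG3 processes the intermediate products in chunks of roughly
    // chunk_fraction * num_prods, which bounds peak memory; ALG2 ignores
    // the fraction.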
    CUSPARSE_CALL(cusparseSpGEMM_estimateMemory(
        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
        matC, dtype, alg, spgemmDesc, chunk_fraction, &workspace_size3, NULL,
        NULL));
    void* workspace3 = device->AllocWorkspace(ctx, workspace_size3);
    CUSPARSE_CALL(cusparseSpGEMM_estimateMemory(
        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
        matC, dtype, alg, spgemmDesc, chunk_fraction, &workspace_size3,
        workspace3, &workspace_size2));
    device->FreeWorkspace(ctx, workspace3);
  } else {
    CUSPARSE_CALL(hipsparseSpGEMM_compute(
        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
        matC, dtype, alg, spgemmDesc, &workspace_size2, NULL));
  }
  // ask for workspace_size2 bytes of external memory
  void* workspace2 = device->AllocWorkspace(ctx, workspace_size2);
  // compute the intermediate product of A * B
  CUSPARSE_CALL(hipsparseSpGEMM_compute(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, alg, spgemmDesc, &workspace_size2, workspace2));
  // get matrix C non-zero entries C_nnz1
  int64_t C_num_rows1, C_num_cols1, C_nnz1;
  CUSPARSE_CALL(
      hipsparseSpMatGetSize(matC, &C_num_rows1, &C_num_cols1, &C_nnz1));
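  // The nnz of C is only known after the compute phase, so the column and
  // value buffers are allocated here and patched into matC below.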
  IdArray dC_columns = IdArray::Empty({C_nnz1}, A.indptr->dtype, A.indptr->ctx);
  NDArray dC_weights =
      NDArray::Empty({C_nnz1}, A_weights_array->dtype, A.indptr->ctx);
  IdType* dC_columns_data = dC_columns.Ptr<IdType>();
  DType* dC_weights_data = dC_weights.Ptr<DType>();
  // update matC with the new pointers
  CUSPARSE_CALL(hipsparseCsrSetPointers(
      matC, dC_csrOffsets_data, dC_columns_data, dC_weights_data));
  // copy the final products to the matrix C
  CUSPARSE_CALL(hipsparseSpGEMM_copy(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, alg, spgemmDesc));

  device->FreeWorkspace(ctx, workspace1);
  device->FreeWorkspace(ctx, workspace2);
  // destroy matrix/vector descriptors
  CUSPARSE_CALL(hipsparseSpGEMM_destroyDescr(spgemmDesc));
  CUSPARSE_CALL(hipsparseDestroySpMat(matA));
  CUSPARSE_CALL(hipsparseDestroySpMat(matB));
  CUSPARSE_CALL(hipsparseDestroySpMat(matC));
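  // The returned CSR carries a NullArray as its data field (DGL's convention
  // for the identity edge-ID mapping); the nonzero values are returned
  // separately as the second element of the pair.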
  return {
      CSRMatrix(
          A.num_rows, B.num_cols, dC_csrOffsets, dC_columns,
          NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx)),
      dC_weights};
}

#else  // DTKRT_VERSION < 12000

/** @brief Cusparse implementation of SpGEMM on Csr format for older
 * toolkit versions (DTKRT < 12.0) */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
    const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
    const NDArray B_weights_array) {
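  // The legacy csrgemm2 interface computes C = alpha * A * B + beta * D in
  // three steps: bufferSizeExt (workspace query), nnz (count the nonzeros of
  // C), and compute. D is unused here, but its descriptor must still exist.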
  int nnzC;
  csrgemm2Info_t info = nullptr;
  size_t workspace_size;
  const DType alpha = 1.;
  const int nnzA = A.indices->shape[0];
  const int nnzB = B.indices->shape[0];
  const int m = A.num_rows;
  const int n = A.num_cols;
  const int k = B.num_cols;
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  auto idtype = A.indptr->dtype;
  auto dtype = A_weights_array->dtype;
  const DType* A_weights = A_weights_array.Ptr<DType>();
  const DType* B_weights = B_weights_array.Ptr<DType>();
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(hipsparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(hipsparseSetStream(thr_entry->cusparse_handle, stream));
  CUSPARSE_CALL(hipsparseSetPointerMode(
      thr_entry->cusparse_handle, HIPSPARSE_POINTER_MODE_HOST));

  CUSPARSE_CALL(hipsparseCreateCsrgemm2Info(&info));

  hipsparseMatDescr_t matA, matB, matC, matD;
  CUSPARSE_CALL(hipsparseCreateMatDescr(&matA));
  CUSPARSE_CALL(hipsparseCreateMatDescr(&matB));
  CUSPARSE_CALL(hipsparseCreateMatDescr(&matC));
  CUSPARSE_CALL(hipsparseCreateMatDescr(&matD));  // needed even if D is null

  CUSPARSE_CALL(CSRGEMM<DType>::bufferSizeExt(
      thr_entry->cusparse_handle, m, n, k, &alpha, matA, nnzA,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), matB, nnzB,
      B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), nullptr, matD, 0,
      nullptr, nullptr, info, &workspace_size));

  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  IdArray C_indptr = IdArray::Empty({m + 1}, idtype, ctx);
  CUSPARSE_CALL(CSRGEMM<DType>::nnz(
      thr_entry->cusparse_handle, m, n, k, matA, nnzA, A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(), matB, nnzB, B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(), matD, 0, nullptr, nullptr, matC,
      C_indptr.Ptr<IdType>(), &nnzC, info, workspace));

  IdArray C_indices = IdArray::Empty({nnzC}, idtype, ctx);
  NDArray C_weights = NDArray::Empty({nnzC}, dtype, ctx);
  CUSPARSE_CALL(CSRGEMM<DType>::compute(
      thr_entry->cusparse_handle, m, n, k, &alpha, matA, nnzA, A_weights,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), matB, nnzB, B_weights,
      B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), nullptr, matD, 0,
      nullptr, nullptr, nullptr, matC, C_weights.Ptr<DType>(),
      C_indptr.Ptr<IdType>(), C_indices.Ptr<IdType>(), info, workspace));

  device->FreeWorkspace(ctx, workspace);
  CUSPARSE_CALL(hipsparseDestroyCsrgemm2Info(info));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(matA));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(matB));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(matC));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(matD));

  return {
      CSRMatrix(
          m, k, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),
      C_weights};
}

#endif  // DTKRT_VERSION >= 12000
}  // namespace cusparse

template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
    const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B,
    NDArray B_weights) {
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  CSRMatrix newA, newB;
  bool cast = false;

  // Cast 64-bit indices to 32-bit; the SpGEMM path below is instantiated
  // with 32-bit indices only.
  if (A.indptr->dtype.bits == 64) {
    newA = CSRMatrix(
        A.num_rows, A.num_cols, AsNumBits(A.indptr, 32),
        AsNumBits(A.indices, 32), AsNumBits(A.data, 32));
    newB = CSRMatrix(
        B.num_rows, B.num_cols, AsNumBits(B.indptr, 32),
        AsNumBits(B.indices, 32), AsNumBits(B.data, 32));
    cast = true;
  }

  // Reorder weights if A or B has edge IDs
  NDArray newA_weights, newB_weights;
  if (CSRHasData(A)) newA_weights = IndexSelect(A_weights, A.data);
  if (CSRHasData(B)) newB_weights = IndexSelect(B_weights, B.data);
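  // A CSR with explicit edge IDs (a non-null data array) stores its weights
  // in edge-ID order, so gather them into nonzero order before multiplying.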

  auto result = cusparse::CusparseSpgemm<DType, int32_t>(
      cast ? newA : A, CSRHasData(A) ? newA_weights : A_weights,
      cast ? newB : B, CSRHasData(B) ? newB_weights : B_weights);

  // Cast 32 bit indices back to 64 bit if necessary
  if (cast) {
    CSRMatrix C = result.first;
    return {
        CSRMatrix(
            C.num_rows, C.num_cols, AsNumBits(C.indptr, 64),
            AsNumBits(C.indices, 64), AsNumBits(C.data, 64)),
        result.second};
  } else {
    return result;
  }
}

template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __half>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __half>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#if BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __hip_bfloat16>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __hip_bfloat16>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#endif  // BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);

}  // namespace aten
}  // namespace dgl