/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/csr_mm.cu
 * @brief SpSpMM/SpGEMM C APIs and definitions.
 */
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>

#include <limits>

#include "../../runtime/cuda/cuda_common.h"
#include "./cusparse_dispatcher.cuh"
#include "./functor.cuh"
namespace dgl {

using namespace dgl::runtime;

namespace aten {
namespace cusparse {

#if CUDART_VERSION >= 12000

/** @brief Cusparse implementation of SpGEMM on Csr format for CUDA 12.0+ */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
    const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
    const NDArray B_weights_array) {
  // We use SpGEMM (SpSpMM) to perform the following operation:
  // C = A x B, where A, B and C are sparse matrices in CSR format.
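  //
  // With CUDA >= 12.0 the generic SpGEMM API is used in stages:
  //   1. cusparseSpGEMM_workEstimation: analyze A and B (called twice, first
  //      to query the workspace size, then to run the analysis);
  //   2. cusparseSpGEMM_estimateMemory (ALG2/ALG3 only) or a size-query call
  //      to cusparseSpGEMM_compute (DEFAULT): size the compute workspace;
  //   3. cusparseSpGEMM_compute: compute the structure and values of C;
  //   4. cusparseSpMatGetSize + cusparseCsrSetPointers + cusparseSpGEMM_copy:
  //      read back C's nnz, attach the output arrays and copy the result.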
  const int nnzA = A.indices->shape[0];
  const int nnzB = B.indices->shape[0];
  const DType alpha = 1.0;
  const DType beta = 0.0;
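  // alpha = 1 and beta = 0: compute the plain, unscaled product A x B.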
  auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
  auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;
  // device
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const DType* A_weights = A_weights_array.Ptr<DType>();
  const DType* B_weights = B_weights_array.Ptr<DType>();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
  // sparse matrix descriptors for A, B and the output C
  cusparseSpMatDescr_t matA, matB, matC;
  IdArray dC_csrOffsets =
      IdArray::Empty({A.num_rows + 1}, A.indptr->dtype, A.indptr->ctx);
  IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>();
  constexpr auto idtype = cusparse_idtype<IdType>::value;
  constexpr auto dtype = cuda_dtype<DType>::value;
  // Create sparse matrix A, B and C in CSR format
  CUSPARSE_CALL(cusparseCreateCsr(
      &matA, A.num_rows, A.num_cols, nnzA, A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(),
      // cusparseCreateCsr only accepts non-const pointers.
      const_cast<DType*>(A_weights),
      idtype, idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
  CUSPARSE_CALL(cusparseCreateCsr(
      &matB, B.num_rows, B.num_cols, nnzB, B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(),
      // cusparseCreateCsr only accepts non-const pointers.
      const_cast<DType*>(B_weights),
      idtype, idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
  CUSPARSE_CALL(cusparseCreateCsr(
      &matC, A.num_rows, B.num_cols, 0, nullptr, nullptr, nullptr, idtype,
      idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
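  // matC starts with nnz = 0 and null buffers; cuSPARSE determines the output
  // size during the compute phase, and the real arrays are attached later via
  // cusparseCsrSetPointers.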
  // SpGEMM Computation
  cusparseSpGEMMDescr_t spgemmDesc;
  cusparseSpGEMMAlg_t alg = CUSPARSE_SPGEMM_DEFAULT;
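  // Start with CUSPARSE_SPGEMM_DEFAULT, which is the fastest algorithm but
  // also the most memory hungry; the code below falls back to ALG2/ALG3
  // (slower, smaller memory footprint; ALG3 additionally works in chunks)
  // based on the estimated number of intermediate products.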

  CUSPARSE_CALL(cusparseSpGEMM_createDescr(&spgemmDesc));
  size_t workspace_size1 = 0, workspace_size2 = 0, workspace_size3 = 0;
  // query the size of the first external workspace buffer (workspace_size1)
  CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, alg, spgemmDesc, &workspace_size1,
      NULL));
  void* workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
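  // The first workEstimation call above only returned the required buffer
  // size; the second call below runs the actual analysis of A and B.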
  // inspect the matrices A and B to understand the memory requirement
  cusparseStatus_t e =
    cusparseSpGEMM_workEstimation(thr_entry->cusparse_handle, transA,
                                  transB, &alpha, matA, matB, &beta,
                                  matC, dtype, alg, spgemmDesc,
                                  &workspace_size1, workspace1);
  // CUSPARSE_SPGEMM_DEFAULT does not support num_prods > 2^31 - 1 and
  // fails with an insufficient-resources error inside the workEstimation call
  if (e == CUSPARSE_STATUS_INSUFFICIENT_RESOURCES) {
    // fall back to ALG2 to estimate num_prods
    alg = CUSPARSE_SPGEMM_ALG2;
    device->FreeWorkspace(ctx, workspace1);
    // rerun cusparseSpGEMM_workEstimation
    CUSPARSE_CALL(cusparseSpGEMM_workEstimation(thr_entry->cusparse_handle,
                                                transA, transB, &alpha, matA,
                                                matB, &beta, matC, dtype, alg,
                                                spgemmDesc, &workspace_size1,
                                                NULL));
    workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
    CUSPARSE_CALL(cusparseSpGEMM_workEstimation(thr_entry->cusparse_handle,
                                                transA, transB, &alpha, matA,
                                                matB, &beta, matC, dtype, alg,
                                                spgemmDesc, &workspace_size1,
                                                workspace1));
  } else {
    CHECK(e == CUSPARSE_STATUS_SUCCESS) << "CUSPARSE ERROR in SpGEMM: " << e;
  }

  // get the number of intermediate products required for the SpGEMM compute;
  // num_prods is a proxy for the device memory SpGEMM needs with ALG2/ALG3
  int64_t num_prods;
  CUSPARSE_CALL(cusparseSpGEMM_getNumProducts(spgemmDesc, &num_prods));

  // assume at least ~15 GB of free GPU memory for the heuristics below to work
  // user-defined medium problem size (below this, DEFAULT is used)
  int64_t MEDIUM_NUM_PRODUCTS = 400000000;  // 400*1000*1000;
  // user-defined large problem size (above this, ALG3 is used)
  int64_t LARGE_NUM_PRODUCTS  = 800000000;  // 800*1000*1000;

  // switch to ALG2/ALG3 for medium & large problem size
  if (alg == CUSPARSE_SPGEMM_DEFAULT && num_prods > MEDIUM_NUM_PRODUCTS) {
    // use ALG3 for very large problem
    alg = num_prods > LARGE_NUM_PRODUCTS ? CUSPARSE_SPGEMM_ALG3 :
      CUSPARSE_SPGEMM_ALG2;

    device->FreeWorkspace(ctx, workspace1);
    // rerun cusparseSpGEMM_workEstimation
    CUSPARSE_CALL(cusparseSpGEMM_workEstimation(thr_entry->cusparse_handle,
                                                transA, transB, &alpha, matA,
                                                matB, &beta, matC, dtype, alg,
                                                spgemmDesc, &workspace_size1,
                                                NULL));
    workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
    CUSPARSE_CALL(cusparseSpGEMM_workEstimation(thr_entry->cusparse_handle,
                                                transA, transB, &alpha, matA,
                                                matB, &beta, matC, dtype, alg,
                                                spgemmDesc, &workspace_size1,
                                                workspace1));
  } else if (alg == CUSPARSE_SPGEMM_ALG2 && num_prods > LARGE_NUM_PRODUCTS) {
    // no need to rerun cusparseSpGEMM_workEstimation between ALG2 and ALG3
    alg = CUSPARSE_SPGEMM_ALG3;
  }

  if (alg == CUSPARSE_SPGEMM_ALG2 || alg == CUSPARSE_SPGEMM_ALG3) {
    // estimate memory for ALG2/ALG3; note chunk_fraction is only used by ALG3
    // reduce chunk_fraction if the computation runs out of memory, at the
    // cost of speed
    float chunk_fraction = num_prods < 4 * LARGE_NUM_PRODUCTS ? 0.15 : 0.05;
    CUSPARSE_CALL(cusparseSpGEMM_estimateMemory(thr_entry->cusparse_handle,
                                                transA, transB, &alpha, matA,
                                                matB, &beta, matC, dtype, alg,
                                                spgemmDesc, chunk_fraction,
                                                &workspace_size3,
                                                NULL, NULL));
    void* workspace3 = (device->AllocWorkspace(ctx, workspace_size3));
    CUSPARSE_CALL(cusparseSpGEMM_estimateMemory(thr_entry->cusparse_handle,
                                                transA, transB, &alpha, matA,
                                                matB, &beta, matC, dtype, alg,
                                                spgemmDesc, chunk_fraction,
                                                &workspace_size3,
                                                workspace3, &workspace_size2));
    device->FreeWorkspace(ctx, workspace3);
  } else {
    CUSPARSE_CALL(cusparseSpGEMM_compute(thr_entry->cusparse_handle,
                                         transA, transB, &alpha, matA,
                                         matB, &beta, matC, dtype, alg,
                                         spgemmDesc, &workspace_size2,
                                         NULL));
  }
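  // At this point workspace_size2 holds the size of the buffer required by
  // cusparseSpGEMM_compute, obtained either from cusparseSpGEMM_estimateMemory
  // (ALG2/ALG3) or from the NULL-buffer size query above (DEFAULT).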
  // allocate workspace_size2 bytes of external memory for the compute step
  void* workspace2 = device->AllocWorkspace(ctx, workspace_size2);
  // compute the intermediate product of A * B
  CUSPARSE_CALL(cusparseSpGEMM_compute(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, alg, spgemmDesc, &workspace_size2,
      workspace2));
  // get the dimensions and number of non-zero entries of matrix C
  int64_t C_num_rows1, C_num_cols1, C_nnz1;
  CUSPARSE_CALL(
      cusparseSpMatGetSize(matC, &C_num_rows1, &C_num_cols1, &C_nnz1));
  IdArray dC_columns = IdArray::Empty({C_nnz1}, A.indptr->dtype, A.indptr->ctx);
  NDArray dC_weights = NDArray::Empty(
      {C_nnz1}, A_weights_array->dtype, A.indptr->ctx);
  IdType* dC_columns_data = dC_columns.Ptr<IdType>();
  DType* dC_weights_data = dC_weights.Ptr<DType>();
  // update matC with the new pointers
  CUSPARSE_CALL(cusparseCsrSetPointers(
      matC, dC_csrOffsets_data, dC_columns_data, dC_weights_data));
  // copy the final products to the matrix C
  CUSPARSE_CALL(cusparseSpGEMM_copy(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, alg, spgemmDesc));

  device->FreeWorkspace(ctx, workspace1);
  device->FreeWorkspace(ctx, workspace2);
  // destroy matrix descriptors and the SpGEMM descriptor
  CUSPARSE_CALL(cusparseSpGEMM_destroyDescr(spgemmDesc));
  CUSPARSE_CALL(cusparseDestroySpMat(matA));
  CUSPARSE_CALL(cusparseDestroySpMat(matB));
  CUSPARSE_CALL(cusparseDestroySpMat(matC));
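  // The returned CSR has a null data array (no edge-ID permutation): the k-th
  // entry of dC_weights is the weight of the k-th non-zero of C.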
  return {
      CSRMatrix(
          A.num_rows, B.num_cols, dC_csrOffsets, dC_columns,
          NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx)),
      dC_weights};
}

#else  // CUDART_VERSION < 12000

/** @brief Cusparse implementation of SpGEMM on Csr format for older CUDA
 * versions */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
    const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
    const NDArray B_weights_array) {
  int nnzC;
  csrgemm2Info_t info = nullptr;
  size_t workspace_size;
  const DType alpha = 1.;
  const int nnzA = A.indices->shape[0];
  const int nnzB = B.indices->shape[0];
  const int m = A.num_rows;
  const int n = A.num_cols;
  const int k = B.num_cols;
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  auto idtype = A.indptr->dtype;
  auto dtype = A_weights_array->dtype;
  const DType* A_weights = A_weights_array.Ptr<DType>();
  const DType* B_weights = B_weights_array.Ptr<DType>();
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
  CUSPARSE_CALL(cusparseSetPointerMode(
      thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST));
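  // Host pointer mode: scalar arguments such as &alpha are read from host
  // memory rather than device memory.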

  CUSPARSE_CALL(cusparseCreateCsrgemm2Info(&info));

  cusparseMatDescr_t matA, matB, matC, matD;
  CUSPARSE_CALL(cusparseCreateMatDescr(&matA));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matB));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matC));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matD));  // needed even if D is null

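  // The legacy csrgemm2 workflow below has three steps: bufferSizeExt sizes a
  // scratch workspace, nnz computes C's row offsets and non-zero count, and
  // compute fills C's column indices and values. csrgemm2 evaluates
  // alpha*A*B + beta*D; beta and D's arrays are null here, so C = A x B.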
  CUSPARSE_CALL(CSRGEMM<DType>::bufferSizeExt(
      thr_entry->cusparse_handle, m, n, k, &alpha, matA, nnzA,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), matB, nnzB,
      B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), nullptr, matD, 0,
      nullptr, nullptr, info, &workspace_size));

  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  IdArray C_indptr = IdArray::Empty({m + 1}, idtype, ctx);
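  // symbolic step: fill C_indptr and count the non-zeros of C (nnzC)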
  CUSPARSE_CALL(CSRGEMM<DType>::nnz(
      thr_entry->cusparse_handle, m, n, k, matA, nnzA, A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(), matB, nnzB, B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(), matD, 0, nullptr, nullptr, matC,
      C_indptr.Ptr<IdType>(), &nnzC, info, workspace));

  IdArray C_indices = IdArray::Empty({nnzC}, idtype, ctx);
  NDArray C_weights = NDArray::Empty({nnzC}, dtype, ctx);
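  // numeric step: fill C's column indices and weight values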
  CUSPARSE_CALL(CSRGEMM<DType>::compute(
      thr_entry->cusparse_handle, m, n, k, &alpha, matA, nnzA, A_weights,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), matB, nnzB, B_weights,
      B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), nullptr, matD, 0,
      nullptr, nullptr, nullptr, matC, C_weights.Ptr<DType>(),
      C_indptr.Ptr<IdType>(), C_indices.Ptr<IdType>(), info, workspace));

  device->FreeWorkspace(ctx, workspace);
  CUSPARSE_CALL(cusparseDestroyCsrgemm2Info(info));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matA));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matB));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matC));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matD));

  return {
      CSRMatrix(
          m, k, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),
      C_weights};
}

#endif  // CUDART_VERSION >= 12000
}  // namespace cusparse

template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
    const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B,
    NDArray B_weights) {
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  CSRMatrix newA, newB;
  bool cast = false;

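  // The cuSPARSE path is only instantiated with 32-bit indices (see the
  // CusparseSpgemm<DType, int32_t> call below), so 64-bit graphs are narrowed
  // to 32 bits here and the result is widened back before returning.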
  // Cast 64 bit indices to 32 bit.
  if (A.indptr->dtype.bits == 64) {
    newA = CSRMatrix(
        A.num_rows, A.num_cols, AsNumBits(A.indptr, 32),
        AsNumBits(A.indices, 32), AsNumBits(A.data, 32));
    newB = CSRMatrix(
        B.num_rows, B.num_cols, AsNumBits(B.indptr, 32),
        AsNumBits(B.indices, 32), AsNumBits(B.data, 32));
    cast = true;
  }

  // Reorder weights if A or B has edge IDs
  NDArray newA_weights, newB_weights;
  if (CSRHasData(A)) newA_weights = IndexSelect(A_weights, A.data);
  if (CSRHasData(B)) newB_weights = IndexSelect(B_weights, B.data);

  auto result = cusparse::CusparseSpgemm<DType, int32_t>(
      cast ? newA : A, CSRHasData(A) ? newA_weights : A_weights,
      cast ? newB : B, CSRHasData(B) ? newB_weights : B_weights);

  // Cast 32 bit indices back to 64 bit if necessary
  if (cast) {
    CSRMatrix C = result.first;
    return {
        CSRMatrix(
            C.num_rows, C.num_cols, AsNumBits(C.indptr, 64),
            AsNumBits(C.indices, 64), AsNumBits(C.data, 64)),
        result.second};
  } else {
    return result;
  }
}
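
// Example usage (illustrative sketch only; assumes A and B are CSR matrices on
// the same CUDA device with float edge weights A_weights and B_weights):
//   std::pair<CSRMatrix, NDArray> result =
//       CSRMM<kDGLCUDA, int64_t, float>(A, A_weights, B, B_weights);
//   const CSRMatrix& C = result.first;  // sparsity structure of A x B
//   NDArray C_weights = result.second;  // non-zero values of A x B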

template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __half>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __half>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#if BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __nv_bfloat16>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __nv_bfloat16>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#endif  // BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);

}  // namespace aten
}  // namespace dgl