/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/csr_mm.cu
 * @brief SpSpMM/SpGEMM C APIs and definitions.
 */
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>

#include <limits>

#include "../../runtime/cuda/cuda_common.h"
#include "./cusparse_dispatcher.cuh"
#include "./functor.cuh"
namespace dgl {

using namespace dgl::runtime;

namespace aten {
namespace cusparse {

#if CUDART_VERSION >= 12000

/** @brief Cusparse implementation of SpGEMM on Csr format for CUDA 12.0+ */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
    const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
    const NDArray B_weights_array) {
  // We use SpGEMM (SpSpMM) to perform the following operation:
  // C = A x B, where A, B and C are sparse matrices in csr format.
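  // The cuSPARSE generic SpGEMM workflow used below is:
  //   1. create CSR descriptors for A, B and an (initially empty) C;
  //   2. cusparseSpGEMM_workEstimation to size and fill the first workspace;
  //   3. (ALG2/ALG3 only) cusparseSpGEMM_estimateMemory to size the compute
  //      workspace;
  //   4. cusparseSpGEMM_compute to form the intermediate product;
  //   5. cusparseSpMatGetSize, cusparseCsrSetPointers and cusparseSpGEMM_copy
  //      to materialize C into freshly allocated DGL arrays.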
  const int nnzA = A.indices->shape[0];
  const int nnzB = B.indices->shape[0];
  const DType alpha = 1.0;
  const DType beta = 0.0;
  auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
  auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;
  // device
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const DType* A_weights = A_weights_array.Ptr<DType>();
  const DType* B_weights = B_weights_array.Ptr<DType>();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
  // cuSPARSE descriptors for the sparse matrices A, B and C
  cusparseSpMatDescr_t matA, matB, matC;
  IdArray dC_csrOffsets =
      IdArray::Empty({A.num_rows + 1}, A.indptr->dtype, A.indptr->ctx);
  IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>();
  constexpr auto idtype = cusparse_idtype<IdType>::value;
  constexpr auto dtype = cuda_dtype<DType>::value;
  // Create sparse matrix A, B and C in CSR format
  CUSPARSE_CALL(cusparseCreateCsr(
      &matA, A.num_rows, A.num_cols, nnzA, A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(),
      // cusparseCreateCsr only accepts non-const pointers.
      const_cast<DType*>(A_weights), idtype, idtype, CUSPARSE_INDEX_BASE_ZERO,
      dtype));
  CUSPARSE_CALL(cusparseCreateCsr(
      &matB, B.num_rows, B.num_cols, nnzB, B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(),
      // cusparseCreateCsr only accepts non-const pointers.
      const_cast<DType*>(B_weights), idtype, idtype, CUSPARSE_INDEX_BASE_ZERO,
      dtype));
  CUSPARSE_CALL(cusparseCreateCsr(
      &matC, A.num_rows, B.num_cols, 0, dC_csrOffsets_data, nullptr, nullptr,
      idtype, idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
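  // C starts with zero non-zeros and null column/value pointers; the real
  // arrays are attached via cusparseCsrSetPointers once C's nnz is known.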
  // SpGEMM Computation
  cusparseSpGEMMDescr_t spgemmDesc;
  cusparseSpGEMMAlg_t alg = CUSPARSE_SPGEMM_DEFAULT;
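  // DEFAULT is the fastest but most memory-hungry algorithm; ALG2 and ALG3
  // lower the memory footprint (ALG3 additionally processes the product in
  // chunks) at the cost of speed.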

  CUSPARSE_CALL(cusparseSpGEMM_createDescr(&spgemmDesc));
  size_t workspace_size1 = 0, workspace_size2 = 0, workspace_size3 = 0;
  // ask for workspace_size1 bytes of external memory (NULL buffer = size query)
  CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, alg, spgemmDesc, &workspace_size1, NULL));
  void* workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
  // inspect the matrices A and B to understand the memory requirement
  cusparseStatus_t e = cusparseSpGEMM_workEstimation(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, alg, spgemmDesc, &workspace_size1, workspace1);
  // CUSPARSE_SPGEMM_DEFAULT does not support num_prods > 2^31 - 1 and throws
  // an insufficient-memory error within the workEstimation call in that case
  if (e == CUSPARSE_STATUS_INSUFFICIENT_RESOURCES) {
    // fall back to ALG2 to estimate num_prods
    alg = CUSPARSE_SPGEMM_ALG2;
    device->FreeWorkspace(ctx, workspace1);
    // rerun cusparseSpGEMM_workEstimation
    CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
        matC, dtype, alg, spgemmDesc, &workspace_size1, NULL));
    workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
    CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
        matC, dtype, alg, spgemmDesc, &workspace_size1, workspace1));
  } else {
    CHECK(e == CUSPARSE_STATUS_SUCCESS) << "CUSPARSE ERROR in SpGEMM: " << e;
  }

  // get the number of intermediate products required for SpGEMM compute
  // num_prods indicates device memory consumption for SpGEMM if using ALG2/3
  int64_t num_prods;
  CUSPARSE_CALL(cusparseSpGEMM_getNumProducts(spgemmDesc, &num_prods));

  // the heuristics below assume at least ~15 GB of free GPU memory
  // user-defined medium problem size (below this, DEFAULT is used)
  int64_t MEDIUM_NUM_PRODUCTS = 400000000;  // 400*1000*1000;
  // user-defined large problem size (above this, ALG3 is used)
  int64_t LARGE_NUM_PRODUCTS = 800000000;  // 800*1000*1000;

  // switch to ALG2/ALG3 for medium & large problem size
  if (alg == CUSPARSE_SPGEMM_DEFAULT && num_prods > MEDIUM_NUM_PRODUCTS) {
    // use ALG3 for very large problem
    alg = num_prods > LARGE_NUM_PRODUCTS ? CUSPARSE_SPGEMM_ALG3
                                         : CUSPARSE_SPGEMM_ALG2;

    device->FreeWorkspace(ctx, workspace1);
    // rerun cusparseSpGEMM_workEstimation
    CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
        matC, dtype, alg, spgemmDesc, &workspace_size1, NULL));
    workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
    CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
        matC, dtype, alg, spgemmDesc, &workspace_size1, workspace1));
  } else if (alg == CUSPARSE_SPGEMM_ALG2 && num_prods > LARGE_NUM_PRODUCTS) {
    // no need to rerun cusparseSpGEMM_workEstimation between ALG2 and ALG3
    alg = CUSPARSE_SPGEMM_ALG3;
  }

  if (alg == CUSPARSE_SPGEMM_ALG2 || alg == CUSPARSE_SPGEMM_ALG3) {
    // estimate memory for ALG2/ALG3; note chunk_fraction is only used by ALG3
    // reduce chunk_fraction if this crashes due to insufficient memory;
    // smaller values trade speed for a lower memory footprint
    float chunk_fraction = num_prods < 4 * LARGE_NUM_PRODUCTS ? 0.15 : 0.05;
    CUSPARSE_CALL(cusparseSpGEMM_estimateMemory(
        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
        matC, dtype, alg, spgemmDesc, chunk_fraction, &workspace_size3, NULL,
        NULL));
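    // the second estimateMemory call, given workspace3, reports in
    // workspace_size2 the buffer size cusparseSpGEMM_compute will need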
    void* workspace3 = (device->AllocWorkspace(ctx, workspace_size3));
    CUSPARSE_CALL(cusparseSpGEMM_estimateMemory(
        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
        matC, dtype, alg, spgemmDesc, chunk_fraction, &workspace_size3,
        workspace3, &workspace_size2));
    device->FreeWorkspace(ctx, workspace3);
  } else {
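    // for DEFAULT, a compute call with a NULL buffer only queries
    // workspace_size2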
    CUSPARSE_CALL(cusparseSpGEMM_compute(
        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
        matC, dtype, alg, spgemmDesc, &workspace_size2, NULL));
  }
  // allocate workspace_size2 bytes of external memory for the compute phase
  void* workspace2 = device->AllocWorkspace(ctx, workspace_size2);
  // compute the intermediate product of A * B
  CUSPARSE_CALL(cusparseSpGEMM_compute(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, alg, spgemmDesc, &workspace_size2, workspace2));
  // get the dimensions and number of non-zero entries (C_nnz1) of matrix C
  int64_t C_num_rows1, C_num_cols1, C_nnz1;
  CUSPARSE_CALL(
      cusparseSpMatGetSize(matC, &C_num_rows1, &C_num_cols1, &C_nnz1));
  IdArray dC_columns = IdArray::Empty({C_nnz1}, A.indptr->dtype, A.indptr->ctx);
  NDArray dC_weights =
      NDArray::Empty({C_nnz1}, A_weights_array->dtype, A.indptr->ctx);
  IdType* dC_columns_data = dC_columns.Ptr<IdType>();
  DType* dC_weights_data = dC_weights.Ptr<DType>();
  // update matC with the new pointers
  CUSPARSE_CALL(cusparseCsrSetPointers(
      matC, dC_csrOffsets_data, dC_columns_data, dC_weights_data));
  // copy the final products to the matrix C
  CUSPARSE_CALL(cusparseSpGEMM_copy(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, alg, spgemmDesc));

  device->FreeWorkspace(ctx, workspace1);
  device->FreeWorkspace(ctx, workspace2);
  // destroy matrix/vector descriptors
  CUSPARSE_CALL(cusparseSpGEMM_destroyDescr(spgemmDesc));
  CUSPARSE_CALL(cusparseDestroySpMat(matA));
  CUSPARSE_CALL(cusparseDestroySpMat(matB));
  CUSPARSE_CALL(cusparseDestroySpMat(matC));
  return {
      CSRMatrix(
          A.num_rows, B.num_cols, dC_csrOffsets, dC_columns,
          NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx)),
      dC_weights};
}

#else  // CUDART_VERSION < 12000

/** @brief Cusparse implementation of SpGEMM on Csr format for older CUDA
 * versions */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
    const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
    const NDArray B_weights_array) {
  int nnzC;
  csrgemm2Info_t info = nullptr;
  size_t workspace_size;
  const DType alpha = 1.;
  const int nnzA = A.indices->shape[0];
  const int nnzB = B.indices->shape[0];
  const int m = A.num_rows;
  const int n = A.num_cols;
  const int k = B.num_cols;
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  auto idtype = A.indptr->dtype;
  auto dtype = A_weights_array->dtype;
  const DType* A_weights = A_weights_array.Ptr<DType>();
  const DType* B_weights = B_weights_array.Ptr<DType>();
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
  CUSPARSE_CALL(cusparseSetPointerMode(
      thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST));

  CUSPARSE_CALL(cusparseCreateCsrgemm2Info(&info));
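  // The legacy csrgemm2 path computes C = alpha * A * B + beta * D in three
  // phases: CSRGEMM::bufferSizeExt (workspace query), CSRGEMM::nnz (row
  // pointer and nnz of C) and CSRGEMM::compute (numeric values). beta and D
  // are unused here, so the result is simply C = A * B.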

  cusparseMatDescr_t matA, matB, matC, matD;
  CUSPARSE_CALL(cusparseCreateMatDescr(&matA));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matB));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matC));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matD));  // needed even if D is null

  CUSPARSE_CALL(CSRGEMM<DType>::bufferSizeExt(
      thr_entry->cusparse_handle, m, n, k, &alpha, matA, nnzA,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), matB, nnzB,
      B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), nullptr, matD, 0,
      nullptr, nullptr, info, &workspace_size));

  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  IdArray C_indptr = IdArray::Empty({m + 1}, idtype, ctx);
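  // symbolic phase: compute C's row pointer and its total number of non-zeros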
  CUSPARSE_CALL(CSRGEMM<DType>::nnz(
      thr_entry->cusparse_handle, m, n, k, matA, nnzA, A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(), matB, nnzB, B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(), matD, 0, nullptr, nullptr, matC,
      C_indptr.Ptr<IdType>(), &nnzC, info, workspace));

  IdArray C_indices = IdArray::Empty({nnzC}, idtype, ctx);
  NDArray C_weights = NDArray::Empty({nnzC}, dtype, ctx);
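  // numeric phase: fill in C's column indices and values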
  CUSPARSE_CALL(CSRGEMM<DType>::compute(
      thr_entry->cusparse_handle, m, n, k, &alpha, matA, nnzA, A_weights,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), matB, nnzB, B_weights,
      B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), nullptr, matD, 0,
      nullptr, nullptr, nullptr, matC, C_weights.Ptr<DType>(),
      C_indptr.Ptr<IdType>(), C_indices.Ptr<IdType>(), info, workspace));

  device->FreeWorkspace(ctx, workspace);
  CUSPARSE_CALL(cusparseDestroyCsrgemm2Info(info));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matA));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matB));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matC));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matD));

  return {
      CSRMatrix(
          m, k, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),
      C_weights};
}

#endif  // CUDART_VERSION >= 12000
}  // namespace cusparse

template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
    const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B,
    NDArray B_weights) {
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  CSRMatrix newA, newB;
  bool cast = false;

  // Cast 64 bit indices to 32 bit.
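  // The cuSPARSE path below is instantiated with int32_t indices only, so
  // 64-bit graphs are narrowed here and widened back after the multiplication.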
  if (A.indptr->dtype.bits == 64) {
    newA = CSRMatrix(
        A.num_rows, A.num_cols, AsNumBits(A.indptr, 32),
        AsNumBits(A.indices, 32), AsNumBits(A.data, 32));
    newB = CSRMatrix(
        B.num_rows, B.num_cols, AsNumBits(B.indptr, 32),
        AsNumBits(B.indices, 32), AsNumBits(B.data, 32));
    cast = true;
  }

  // Reorder weights if A or B has edge IDs
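  // When present, A.data / B.data hold edge IDs that permute the weight
  // arrays, so gather the weights into CSR order before multiplying.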
  NDArray newA_weights, newB_weights;
  if (CSRHasData(A)) newA_weights = IndexSelect(A_weights, A.data);
  if (CSRHasData(B)) newB_weights = IndexSelect(B_weights, B.data);

  auto result = cusparse::CusparseSpgemm<DType, int32_t>(
      cast ? newA : A, CSRHasData(A) ? newA_weights : A_weights,
      cast ? newB : B, CSRHasData(B) ? newB_weights : B_weights);

  // Cast 32 bit indices back to 64 bit if necessary
  if (cast) {
    CSRMatrix C = result.first;
    return {
        CSRMatrix(
            C.num_rows, C.num_cols, AsNumBits(C.indptr, 64),
            AsNumBits(C.indices, 64), AsNumBits(C.data, 64)),
        result.second};
  } else {
    return result;
  }
}
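
// Illustrative usage sketch (the variables A, A_weights, B and B_weights are
// hypothetical; the instantiation matches the explicit template instantiations
// below): multiply two float-weighted CSR matrices that already live on the
// GPU and unpack the resulting structure and values.
//
//   std::pair<CSRMatrix, NDArray> result =
//       CSRMM<kDGLCUDA, int32_t, float>(A, A_weights, B, B_weights);
//   CSRMatrix C = result.first;         // sparsity structure of A x B
//   NDArray C_weights = result.second;  // non-zero values of A x B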

template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __half>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __half>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#if BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __nv_bfloat16>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __nv_bfloat16>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#endif  // BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);

}  // namespace aten
}  // namespace dgl