/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/csr_mm.cu
 * @brief SpSpMM/SpGEMM C APIs and definitions.
 */
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./cusparse_dispatcher.cuh"
#include "./functor.cuh"

namespace dgl {

using namespace dgl::runtime;
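
// This file implements C = A x B for CSR sparse matrices via cuSPARSE.
// A minimal, hypothetical call site (the matrices and weight arrays here
// are illustrative placeholders, not defined in this file):
//
//   std::pair<CSRMatrix, NDArray> res =
//       aten::CSRMM<kDGLCUDA, int32_t, float>(A, A_weights, B, B_weights);
//   CSRMatrix C = res.first;         // sparsity structure of A x B
//   NDArray C_weights = res.second;  // one weight per non-zero of C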

namespace aten {
namespace cusparse {

#if 0  // Disabling the CUDA 11.0+ implementation for now because of
       // problems on bigger graphs.

/** @brief Cusparse implementation of SpGEMM on Csr format for CUDA 11.0+ */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
    const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
    const NDArray B_weights_array) {
  // We use Spgemm (SpSpMM) to perform following operation:
  // C = A x B, where A, B and C are sparse matrices in csr format.
  const int nnzA = A.indices->shape[0];
  const int nnzB = B.indices->shape[0];
  const DType alpha = 1.0;
  const DType beta = 0.0;
  auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
  auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;
  // device
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const DType* A_weights = A_weights_array.Ptr<DType>();
  const DType* B_weights = B_weights_array.Ptr<DType>();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
  // Generic-API sparse matrix descriptors for A, B and the output C.
  cusparseSpMatDescr_t matA, matB, matC;
  IdArray dC_csrOffsets =
      IdArray::Empty({A.num_rows + 1}, A.indptr->dtype, A.indptr->ctx);
  IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>();
  constexpr auto idtype = cusparse_idtype<IdType>::value;
  constexpr auto dtype = cuda_dtype<DType>::value;
  // Create sparse matrix A, B and C in CSR format
  CUSPARSE_CALL(cusparseCreateCsr(
      &matA, A.num_rows, A.num_cols, nnzA, A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(),
      // cusparseCreateCsr only accepts non-const pointers.
      const_cast<DType*>(A_weights),
      idtype, idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
  CUSPARSE_CALL(cusparseCreateCsr(
      &matB, B.num_rows, B.num_cols, nnzB, B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(),
      // cusparseCreateCsr only accepts non-const pointers.
      const_cast<DType*>(B_weights),
      idtype, idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
  CUSPARSE_CALL(cusparseCreateCsr(
      &matC, A.num_rows, B.num_cols, 0, nullptr, nullptr, nullptr, idtype,
      idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
  // SpGEMM Computation
  cusparseSpGEMMDescr_t spgemmDesc;
  CUSPARSE_CALL(cusparseSpGEMM_createDescr(&spgemmDesc));
  size_t workspace_size1 = 0, workspace_size2 = 0;
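  // The generic SpGEMM API follows a query/allocate/execute idiom: each
  // step is first issued with a null buffer to obtain the required
  // workspace size, then issued again with the allocated workspace.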
  // Query how many bytes of external workspace the work estimation needs.
  CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size1,
      nullptr));
  void* workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
  // Inspect the matrices A and B to understand the memory requirement
  // for the next step.
  CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size1,
      workspace1));
  // Query how many bytes of external workspace the compute step needs.
  CUSPARSE_CALL(cusparseSpGEMM_compute(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size2,
      nullptr));
  void* workspace2 = device->AllocWorkspace(ctx, workspace_size2);
  // compute the intermediate product of A * B
  CUSPARSE_CALL(cusparseSpGEMM_compute(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size2,
      workspace2));
  // get matrix C non-zero entries C_nnz1
  int64_t C_num_rows1, C_num_cols1, C_nnz1;
  CUSPARSE_CALL(
      cusparseSpMatGetSize(matC, &C_num_rows1, &C_num_cols1, &C_nnz1));
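  // C's column-index and value buffers can only be allocated now that the
  // symbolic phase has determined the number of non-zeros in the output.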
  IdArray dC_columns = IdArray::Empty({C_nnz1}, A.indptr->dtype, A.indptr->ctx);
  NDArray dC_weights = NDArray::Empty(
      {C_nnz1}, A_weights_array->dtype, A.indptr->ctx);
  IdType* dC_columns_data = dC_columns.Ptr<IdType>();
  DType* dC_weights_data = dC_weights.Ptr<DType>();
  // update matC with the new pointers
  CUSPARSE_CALL(cusparseCsrSetPointers(
      matC, dC_csrOffsets_data, dC_columns_data, dC_weights_data));
  // copy the final products to the matrix C
  CUSPARSE_CALL(cusparseSpGEMM_copy(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc));

  device->FreeWorkspace(ctx, workspace1);
  device->FreeWorkspace(ctx, workspace2);
  // destroy matrix/vector descriptors
  CUSPARSE_CALL(cusparseSpGEMM_destroyDescr(spgemmDesc));
  CUSPARSE_CALL(cusparseDestroySpMat(matA));
  CUSPARSE_CALL(cusparseDestroySpMat(matB));
  CUSPARSE_CALL(cusparseDestroySpMat(matC));
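  // The returned CSR carries a NullArray as its data field: C's weights
  // live in the separate dC_weights array instead of being permuted
  // through edge IDs.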
  return {
      CSRMatrix(
          A.num_rows, B.num_cols, dC_csrOffsets, dC_columns,
          NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx)),
      dC_weights};
}

#else  // Legacy csrgemm2 implementation; currently always compiled in
       // because of the #if 0 above.

/** @brief Cusparse implementation of SpGEMM on Csr format for older CUDA
 * versions */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
    const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
    const NDArray B_weights_array) {
  int nnzC;
  csrgemm2Info_t info = nullptr;
  size_t workspace_size;
  const DType alpha = 1.;
  const int nnzA = A.indices->shape[0];
  const int nnzB = B.indices->shape[0];
  const int m = A.num_rows;
  const int n = A.num_cols;
  const int k = B.num_cols;
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  auto idtype = A.indptr->dtype;
  auto dtype = A_weights_array->dtype;
  const DType* A_weights = A_weights_array.Ptr<DType>();
  const DType* B_weights = B_weights_array.Ptr<DType>();
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
  CUSPARSE_CALL(cusparseSetPointerMode(
      thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST));
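  // alpha (and the null beta below) are host-side scalars, hence
  // CUSPARSE_POINTER_MODE_HOST.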

  CUSPARSE_CALL(cusparseCreateCsrgemm2Info(&info));

  cusparseMatDescr_t matA, matB, matC, matD;
  CUSPARSE_CALL(cusparseCreateMatDescr(&matA));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matB));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matC));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matD));  // needed even if D is null
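
  // csrgemm2 computes C = alpha * A * B + beta * D; passing a null D
  // (nnzD = 0) and a null beta reduces this to C = A * B. The legacy API
  // runs in three phases: bufferSizeExt (workspace query), nnz (symbolic
  // phase, filling C's indptr and non-zero count), and compute (numeric
  // phase, filling C's indices and values).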
  CUSPARSE_CALL(CSRGEMM<DType>::bufferSizeExt(
      thr_entry->cusparse_handle, m, n, k, &alpha, matA, nnzA,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), matB, nnzB,
      B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), nullptr, matD, 0,
      nullptr, nullptr, info, &workspace_size));

  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  IdArray C_indptr = IdArray::Empty({m + 1}, idtype, ctx);
  CUSPARSE_CALL(CSRGEMM<DType>::nnz(
      thr_entry->cusparse_handle, m, n, k, matA, nnzA, A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(), matB, nnzB, B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(), matD, 0, nullptr, nullptr, matC,
      C_indptr.Ptr<IdType>(), &nnzC, info, workspace));

  IdArray C_indices = IdArray::Empty({nnzC}, idtype, ctx);
  NDArray C_weights = NDArray::Empty({nnzC}, dtype, ctx);
  CUSPARSE_CALL(CSRGEMM<DType>::compute(
      thr_entry->cusparse_handle, m, n, k, &alpha, matA, nnzA, A_weights,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), matB, nnzB, B_weights,
      B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), nullptr, matD, 0,
      nullptr, nullptr, nullptr, matC, C_weights.Ptr<DType>(),
      C_indptr.Ptr<IdType>(), C_indices.Ptr<IdType>(), info, workspace));

  device->FreeWorkspace(ctx, workspace);
  CUSPARSE_CALL(cusparseDestroyCsrgemm2Info(info));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matA));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matB));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matC));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matD));

  return {
      CSRMatrix(
          m, k, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),
      C_weights};
}

#endif  // 0 (CUDA 11.0+ implementation disabled)
}  // namespace cusparse

template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
    const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B,
    NDArray B_weights) {
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  CSRMatrix newA, newB;
  bool cast = false;

  // Cast 64 bit indices to 32 bit.
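  // (The cuSPARSE csrgemm2 path works with 32-bit indices, which is also
  // why CusparseSpgemm is instantiated with int32_t below.)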
  if (A.indptr->dtype.bits == 64) {
    newA = CSRMatrix(
        A.num_rows, A.num_cols, AsNumBits(A.indptr, 32),
        AsNumBits(A.indices, 32), AsNumBits(A.data, 32));
    newB = CSRMatrix(
        B.num_rows, B.num_cols, AsNumBits(B.indptr, 32),
        AsNumBits(B.indices, 32), AsNumBits(B.data, 32));
    cast = true;
  }

  // Reorder weights if A or B has edge IDs
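  // (A DGL CSR's data field, when present, maps each stored entry to its
  // edge ID; IndexSelect gathers the weights into CSR storage order, the
  // layout cuSPARSE expects.)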
  NDArray newA_weights, newB_weights;
  if (CSRHasData(A)) newA_weights = IndexSelect(A_weights, A.data);
  if (CSRHasData(B)) newB_weights = IndexSelect(B_weights, B.data);

  auto result = cusparse::CusparseSpgemm<DType, int32_t>(
      cast ? newA : A, CSRHasData(A) ? newA_weights : A_weights,
      cast ? newB : B, CSRHasData(B) ? newB_weights : B_weights);

  // Cast 32 bit indices back to 64 bit if necessary
  if (cast) {
    CSRMatrix C = result.first;
    return {
        CSRMatrix(
            C.num_rows, C.num_cols, AsNumBits(C.indptr, 64),
            AsNumBits(C.indices, 64), AsNumBits(C.data, 64)),
        result.second};
  } else {
    return result;
  }
}

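// Explicit instantiations for all supported index/value type combinations.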
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __half>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __half>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#if BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __nv_bfloat16>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __nv_bfloat16>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#endif  // BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);

}  // namespace aten
}  // namespace dgl