/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/csr_mm.cu
 * \brief SpSpMM/SpGEMM C APIs and definitions.
 */
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>
#include "./functor.cuh"
#include "./cusparse_dispatcher.cuh"
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {

using namespace dgl::runtime;

namespace aten {
namespace cusparse {

#if 0   // disabling CUDA 11.0+ implementation for now because of problems on bigger graphs

/*! \brief Cusparse implementation of SpGEMM on Csr format for CUDA 11.0+ */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
    const CSRMatrix& A,
    const NDArray A_weights_array,
    const CSRMatrix& B,
    const NDArray B_weights_array) {
  // We use SpGEMM (SpSpMM) to perform the following operation:
  // C = A x B, where A, B and C are sparse matrices in CSR format.
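  // The generic SpGEMM API below proceeds in phases: hipsparseSpGEMM_workEstimation
  // is called twice (first with a null buffer to query the workspace size, then
  // with the allocated buffer), hipsparseSpGEMM_compute likewise twice, after which
  // hipsparseSpMatGetSize reports C's nnz so its output buffers can be allocated,
  // bound via hipsparseCsrSetPointers, and filled by hipsparseSpGEMM_copy.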
  const int nnzA = A.indices->shape[0];
  const int nnzB = B.indices->shape[0];
  const DType alpha = 1.0;
  const DType beta = 0.0;
  auto transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  auto transB = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  // device
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  hipStream_t stream = runtime::getCurrentCUDAStream();
  const DType* A_weights = A_weights_array.Ptr<DType>();
  const DType* B_weights = B_weights_array.Ptr<DType>();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(hipsparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(hipsparseSetStream(thr_entry->cusparse_handle, stream));
  // descriptors for A, B and the output matrix C
  hipsparseSpMatDescr_t matA, matB, matC;
  IdArray dC_csrOffsets = IdArray::Empty({A.num_rows+1}, A.indptr->dtype, A.indptr->ctx);
  IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>();
  constexpr auto idtype = cusparse_idtype<IdType>::value;
  constexpr auto dtype = cuda_dtype<DType>::value;
  // Create sparse matrix A, B and C in CSR format
  CUSPARSE_CALL(hipsparseCreateCsr(&matA,
      A.num_rows, A.num_cols, nnzA,
      A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(),
      const_cast<DType*>(A_weights),    // hipsparseCreateCsr only accepts non-const pointers
      idtype, idtype, HIPSPARSE_INDEX_BASE_ZERO, dtype));
  CUSPARSE_CALL(hipsparseCreateCsr(&matB,
      B.num_rows, B.num_cols, nnzB,
      B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(),
      const_cast<DType*>(B_weights),    // hipsparseCreateCsr only accepts non-const pointers
      idtype, idtype, HIPSPARSE_INDEX_BASE_ZERO, dtype));
  CUSPARSE_CALL(hipsparseCreateCsr(&matC,
      A.num_rows, B.num_cols, 0,
      nullptr, nullptr, nullptr, idtype, idtype,
      HIPSPARSE_INDEX_BASE_ZERO, dtype));
  // SpGEMM Computation
  hipsparseSpGEMMDescr_t spgemmDesc;
  CUSPARSE_CALL(hipsparseSpGEMM_createDescr(&spgemmDesc));
  size_t workspace_size1 = 0, workspace_size2 = 0;
  // query the size of the first external workspace buffer
  CUSPARSE_CALL(hipsparseSpGEMM_workEstimation(
      thr_entry->cusparse_handle, transA, transB,
      &alpha, matA, matB, &beta, matC, dtype,
      HIPSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size1,
      NULL));
  void* workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
  // inspect the matrices A and B to understand the memory requirement
  // for the next step
  CUSPARSE_CALL(hipsparseSpGEMM_workEstimation(
      thr_entry->cusparse_handle, transA, transB,
      &alpha, matA, matB, &beta, matC, dtype,
      HIPSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size1,
      workspace1));
  // query the size of the second external workspace buffer
  CUSPARSE_CALL(hipsparseSpGEMM_compute(thr_entry->cusparse_handle,
      transA, transB, &alpha, matA, matB, &beta, matC,
      dtype, HIPSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size2,
      NULL));
  void* workspace2 = device->AllocWorkspace(ctx, workspace_size2);
  // compute the intermediate product of A * B
  CUSPARSE_CALL(hipsparseSpGEMM_compute(thr_entry->cusparse_handle,
      transA, transB, &alpha, matA, matB, &beta, matC,
      dtype, HIPSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size2,
      workspace2));
  // get matrix C non-zero entries C_nnz1
  int64_t C_num_rows1, C_num_cols1, C_nnz1;
  CUSPARSE_CALL(hipsparseSpMatGetSize(matC, &C_num_rows1, &C_num_cols1, &C_nnz1));
  IdArray dC_columns = IdArray::Empty({C_nnz1}, A.indptr->dtype, A.indptr->ctx);
  NDArray dC_weights = NDArray::Empty({C_nnz1}, A_weights_array->dtype, A.indptr->ctx);
  IdType* dC_columns_data = dC_columns.Ptr<IdType>();
  DType* dC_weights_data = dC_weights.Ptr<DType>();
  // update matC with the new pointers
  CUSPARSE_CALL(hipsparseCsrSetPointers(matC, dC_csrOffsets_data,
     dC_columns_data, dC_weights_data));
  // copy the final products to the matrix C
  CUSPARSE_CALL(hipsparseSpGEMM_copy(thr_entry->cusparse_handle,
      transA, transB, &alpha, matA, matB, &beta, matC,
      dtype, HIPSPARSE_SPGEMM_DEFAULT, spgemmDesc));

  device->FreeWorkspace(ctx, workspace1);
  device->FreeWorkspace(ctx, workspace2);
  // destroy matrix/vector descriptors
  CUSPARSE_CALL(hipsparseSpGEMM_destroyDescr(spgemmDesc));
  CUSPARSE_CALL(hipsparseDestroySpMat(matA));
  CUSPARSE_CALL(hipsparseDestroySpMat(matB));
  CUSPARSE_CALL(hipsparseDestroySpMat(matC));
  return {
      CSRMatrix(A.num_rows, B.num_cols, dC_csrOffsets, dC_columns,
                NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx)),
      dC_weights};
}

#else   // fallback implementation based on csrgemm2

/*! \brief Cusparse implementation of SpGEMM on Csr format for older CUDA versions */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
    const CSRMatrix& A,
    const NDArray A_weights_array,
    const CSRMatrix& B,
    const NDArray B_weights_array) {
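  // The csrgemm2 path takes three steps: bufferSizeExt sizes the scratch
  // buffer, the nnz call fills C's row pointer array and returns its nonzero
  // count, and the final compute call fills C's column indices and values.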
  int nnzC;
  csrgemm2Info_t info = nullptr;
  size_t workspace_size;
  const DType alpha = 1.;
  const int nnzA = A.indices->shape[0];
  const int nnzB = B.indices->shape[0];
  const int m = A.num_rows;
  const int n = A.num_cols;
  const int k = B.num_cols;
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  hipStream_t stream = runtime::getCurrentCUDAStream();
  auto idtype = A.indptr->dtype;
  auto dtype = A_weights_array->dtype;
  const DType* A_weights = A_weights_array.Ptr<DType>();
  const DType* B_weights = B_weights_array.Ptr<DType>();
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(hipsparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(hipsparseSetStream(thr_entry->cusparse_handle, stream));
  CUSPARSE_CALL(hipsparseSetPointerMode(
      thr_entry->cusparse_handle, HIPSPARSE_POINTER_MODE_HOST));

  CUSPARSE_CALL(hipsparseCreateCsrgemm2Info(&info));

  hipsparseMatDescr_t matA, matB, matC, matD;
  CUSPARSE_CALL(hipsparseCreateMatDescr(&matA));
  CUSPARSE_CALL(hipsparseCreateMatDescr(&matB));
  CUSPARSE_CALL(hipsparseCreateMatDescr(&matC));
  CUSPARSE_CALL(hipsparseCreateMatDescr(&matD));   // needed even if D is null
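  // csrgemm2 computes C = alpha * A * B + beta * D; D is passed below as an
  // empty matrix (nnz = 0, null pointers) with a null beta, so only the
  // A * B term contributes.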

  CUSPARSE_CALL(CSRGEMM<DType>::bufferSizeExt(thr_entry->cusparse_handle,
      m, n, k, &alpha,
      matA, nnzA, A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(),
      matB, nnzB, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(),
      nullptr,
      matD, 0, nullptr, nullptr,
      info,
      &workspace_size));

  void *workspace = device->AllocWorkspace(ctx, workspace_size);
  IdArray C_indptr = IdArray::Empty({m + 1}, idtype, ctx);
  CUSPARSE_CALL(CSRGEMM<DType>::nnz(thr_entry->cusparse_handle,
      m, n, k,
      matA, nnzA, A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(),
      matB, nnzB, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(),
      matD, 0, nullptr, nullptr,
      matC, C_indptr.Ptr<IdType>(), &nnzC, info, workspace));

  IdArray C_indices = IdArray::Empty({nnzC}, idtype, ctx);
  NDArray C_weights = NDArray::Empty({nnzC}, dtype, ctx);
  CUSPARSE_CALL(CSRGEMM<DType>::compute(thr_entry->cusparse_handle,
      m, n, k, &alpha,
      matA, nnzA, A_weights, A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(),
      matB, nnzB, B_weights, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(),
      nullptr,
      matD, 0, nullptr, nullptr, nullptr,
      matC, C_weights.Ptr<DType>(), C_indptr.Ptr<IdType>(), C_indices.Ptr<IdType>(),
      info, workspace));

  device->FreeWorkspace(ctx, workspace);
  CUSPARSE_CALL(hipsparseDestroyCsrgemm2Info(info));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(matA));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(matB));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(matC));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(matD));

  return {
      CSRMatrix(m, k, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),
      C_weights};
}

#endif  // disabled CUDA 11.0+ implementation
}  // namespace cusparse

template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
    const CSRMatrix& A,
    NDArray A_weights,
    const CSRMatrix& B,
    NDArray B_weights) {
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  CSRMatrix newA, newB;
  bool cast = false;

  // Cast 64 bit indices to 32 bit.
  if (A.indptr->dtype.bits == 64) {
    newA = CSRMatrix(
        A.num_rows, A.num_cols,
        AsNumBits(A.indptr, 32), AsNumBits(A.indices, 32), AsNumBits(A.data, 32));
    newB = CSRMatrix(
        B.num_rows, B.num_cols,
        AsNumBits(B.indptr, 32), AsNumBits(B.indices, 32), AsNumBits(B.data, 32));
    cast = true;
  }
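  // The csrgemm2 routines operate on 32-bit (int) indices, so 64-bit graphs
  // are narrowed here and the result is widened back after the product is
  // formed.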

  // Reorder weights if A or B has edge IDs
  NDArray newA_weights, newB_weights;
  if (CSRHasData(A))
    newA_weights = IndexSelect(A_weights, A.data);
  if (CSRHasData(B))
    newB_weights = IndexSelect(B_weights, B.data);
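  // When present, A.data / B.data hold edge IDs that map each CSR nonzero to
  // its position in the weight array, so IndexSelect gathers the weights into
  // CSR nonzero order before they are handed to the SpGEMM routine.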

  auto result = cusparse::CusparseSpgemm<DType, int32_t>(
      cast ? newA : A, CSRHasData(A) ? newA_weights : A_weights,
      cast ? newB : B, CSRHasData(B) ? newB_weights : B_weights);

  // Cast 32 bit indices back to 64 bit if necessary
  if (cast) {
    CSRMatrix C = result.first;
    return {
      CSRMatrix(C.num_rows, C.num_cols, AsNumBits(C.indptr, 64), AsNumBits(C.indices, 64),
                AsNumBits(C.data, 64)),
      result.second};
  } else {
    return result;
  }
}
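
// A minimal usage sketch (illustrative only; the variable names below are
// hypothetical and assume both operands live in the same GPU context with one
// weight per nonzero):
//
//   CSRMatrix A = ..., B = ...;                      // A: m x n, B: n x k
//   NDArray A_w = ..., B_w = ...;                    // float edge weights
//   auto res = CSRMM<kDLROCM, int64_t, float>(A, A_w, B, B_w);
//   CSRMatrix C = res.first;                         // structure of A x B
//   NDArray C_w = res.second;                        // weights of C's nonzeros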

#ifdef USE_FP16
template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int32_t, __half>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int64_t, __half>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#endif  // USE_FP16

template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int32_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int64_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int32_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLROCM, int64_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);

}  // namespace aten
}  // namespace dgl