/*!
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/csr_sum.cu
 * @brief SpGEAM (sparse matrix addition) C APIs and definitions.
 */
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>
#include "./functor.cuh"
#include "./cusparse_dispatcher.cuh"
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {

using namespace dgl::runtime;

namespace aten {
namespace cusparse {

/*! @brief cuSPARSE implementation of SpSum on the CSR format. */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
    const CSRMatrix& A,
    const NDArray A_weights_array,
    const CSRMatrix& B,
    const NDArray B_weights_array) {
  const int m = A.num_rows;
  const int n = A.num_cols;
  const int nnzA = A.indices->shape[0];
  const int nnzB = B.indices->shape[0];
  int nnzC;
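  // csrgeam2 computes C = alpha * A + beta * B; alpha = beta = 1 yields the
  // entrywise sum of the two sparse matrices.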
  const DType alpha = 1.0;
  const DType beta = 1.0;
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const DType* A_weights = A_weights_array.Ptr<DType>();
  const DType* B_weights = B_weights_array.Ptr<DType>();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle)
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));

  cusparseMatDescr_t matA, matB, matC;
  CUSPARSE_CALL(cusparseCreateMatDescr(&matA));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matB));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matC));

  // Host pointer mode: alpha/beta and the returned nnzC are host-side scalars.
  CUSPARSE_CALL(cusparseSetPointerMode(
      thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST));
  size_t workspace_size = 0;
  /* prepare output C */
  IdArray dC_csrOffsets = IdArray::Empty({m + 1}, A.indptr->dtype, ctx);
  IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>();
  IdArray dC_columns;
  NDArray dC_weights;
  IdType* dC_columns_data = dC_columns.Ptr<IdType>();
  DType* dC_weights_data = dC_weights.Ptr<DType>();
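  // dC_columns and dC_weights are deliberately still empty here, so the two
  // pointers above are null; the number of nonzeros of C is unknown until the
  // nnz phase below, after which both arrays are allocated to exact size.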
  /* prepare buffer */
  CUSPARSE_CALL(CSRGEAM<DType>::bufferSizeExt(
      thr_entry->cusparse_handle, m, n, &alpha,
      matA, nnzA, A_weights,
      A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(),
      &beta, matB, nnzB, B_weights,
      B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(),
      matC, dC_weights_data, dC_csrOffsets_data, dC_columns_data,
      &workspace_size));

  void *workspace = device->AllocWorkspace(ctx, workspace_size);
  CUSPARSE_CALL(CSRGEAM<DType>::nnz(thr_entry->cusparse_handle,
      m, n, matA, nnzA,
      A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(),
      matB, nnzB,
      B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(),
      matC, dC_csrOffsets_data, &nnzC, workspace));
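  // nnzC has been computed, so the column index and weight arrays of C can now
  // be allocated with their exact sizes.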

  dC_columns = IdArray::Empty({nnzC}, A.indptr->dtype, ctx);
  dC_weights = NDArray::Empty({nnzC}, A_weights_array->dtype, ctx);
  dC_columns_data = dC_columns.Ptr<IdType>();
  dC_weights_data = dC_weights.Ptr<DType>();

  CUSPARSE_CALL(CSRGEAM<DType>::compute(
      thr_entry->cusparse_handle, m, n, &alpha,
      matA, nnzA, A_weights,
      A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(),
      &beta, matB, nnzB, B_weights,
      B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(),
      matC, dC_weights_data, dC_csrOffsets_data, dC_columns_data,
      workspace));

  device->FreeWorkspace(ctx, workspace);
  // destroy matrix/vector descriptors
  CUSPARSE_CALL(cusparseDestroyMatDescr(matA));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matB));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matC));
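
  // The returned CSR stores no explicit edge ID array (NullArray); the summed
  // weights are returned separately in dC_weights, aligned with dC_columns.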
  return {
    CSRMatrix(A.num_rows, A.num_cols, dC_csrOffsets, dC_columns,
              NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx), true),
    dC_weights};
}
}  // namespace cusparse

template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRSum(
    const std::vector<CSRMatrix>& As,
    const std::vector<NDArray>& A_weights) {
  const int64_t M = As[0].num_rows;
  const int64_t N = As[0].num_cols;
  const int64_t n = As.size();
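  // All n input matrices are expected to share the same shape (M x N).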

  // Cast 64-bit indices to 32-bit, since this path calls cuSPARSE csrgeam2
  // with int32 indices.
  std::vector<CSRMatrix> newAs;
  newAs.reserve(n);
  bool cast = false;
  if (As[0].indptr->dtype.bits == 64) {
    for (int i = 0; i < n; ++i)
      newAs.emplace_back(
        As[i].num_rows, As[i].num_cols, AsNumBits(As[i].indptr, 32),
        AsNumBits(As[i].indices, 32), AsNumBits(As[i].data, 32));
    cast = true;
  } else {
    for (int i = 0; i < n; ++i)
      newAs.push_back(As[i]);
  }

  // cuSPARSE csrgeam2 requires the CSR to be sorted.
  // TODO(BarclayII): ideally the sorted CSR should be cached but I'm not sure how to do it.
  for (int i = 0; i < n; ++i) {
    if (!newAs[i].sorted)
      newAs[i] = CSRSort(newAs[i]);
  }

  // Reorder weights if A[i] has edge IDs
  std::vector<NDArray> A_weights_reordered(n);
  for (int i = 0; i < n; ++i) {
    if (CSRHasData(newAs[i]))
      A_weights_reordered[i] = IndexSelect(A_weights[i], newAs[i].data);
    else
      A_weights_reordered[i] = A_weights[i];
  }

  // Accumulate the sum pairwise with cuSPARSE csrgeam2:
  // result = ((A0 + A1) + A2) + ...
  auto result = std::make_pair(
      CSRMatrix(
        newAs[0].num_rows, newAs[0].num_cols,
        newAs[0].indptr, newAs[0].indices,
        NullArray(newAs[0].indptr->dtype, newAs[0].indptr->ctx)),
      A_weights_reordered[0]);  // Weights already reordered so we don't need As[0].data
  for (int64_t i = 1; i < n; ++i)
    result = cusparse::CusparseCsrgeam2<DType, int32_t>(
        result.first, result.second, newAs[i], A_weights_reordered[i]);

  // Cast 32 bit indices back to 64 bit if necessary
  if (cast) {
    CSRMatrix C = result.first;
    return {
      CSRMatrix(C.num_rows, C.num_cols, AsNumBits(C.indptr, 64), AsNumBits(C.indices, 64),
                AsNumBits(C.data, 64), true),
      result.second};
  } else {
    return result;
  }
}
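
// A minimal usage sketch (illustrative only: the matrix/weight names below are
// hypothetical, and all operands are assumed to share one CUDA context):
//
//   std::vector<CSRMatrix> mats = {A, B};      // CSR matrices of equal shape
//   std::vector<NDArray> weights = {Aw, Bw};   // one weight per nonzero
//   auto res = CSRSum<kDGLCUDA, int64_t, float>(mats, weights);
//   CSRMatrix C = res.first;                   // structure of A + B
//   NDArray C_weights = res.second;            // summed edge weights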

template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __half>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __half>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
#if BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
#endif  // BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, float>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, float>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, double>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, double>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);

}  // namespace aten
}  // namespace dgl