/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/csr_sum.cu
 * @brief SpGEAM C APIs and definitions.
 */
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./cusparse_dispatcher.cuh"
#include "./functor.cuh"

namespace dgl {

using namespace dgl::runtime;

namespace aten {
namespace cusparse {

/** cuSPARSE implementation of SpSum on CSR format. */
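// cusparseXcsrgeam2 computes C = alpha * A + beta * B in three phases:
// (1) csrgeam2_bufferSizeExt queries the required workspace size,
// (2) csrgeam2Nnz computes C's row offsets and total number of nonzeros, and
// (3) csrgeam2 fills in C's column indices and values.
// With alpha = beta = 1 this yields the entrywise sum of A and B.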
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
    const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
    const NDArray B_weights_array) {
  const int m = A.num_rows;
  const int n = A.num_cols;
  const int nnzA = A.indices->shape[0];
  const int nnzB = B.indices->shape[0];
  int nnzC;
  const DType alpha = 1.0;
  const DType beta = 1.0;
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const DType* A_weights = A_weights_array.Ptr<DType>();
  const DType* B_weights = B_weights_array.Ptr<DType>();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle)
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));

  cusparseMatDescr_t matA, matB, matC;
  CUSPARSE_CALL(cusparseCreateMatDescr(&matA));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matB));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matC));

  cusparseSetPointerMode(
      thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST);
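  // Host pointer mode: the alpha/beta scalars and the nnzC count returned by
  // the nnz phase live in host memory.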
  size_t workspace_size = 0;
  /* prepare output C */
  IdArray dC_csrOffsets = IdArray::Empty({m + 1}, A.indptr->dtype, ctx);
  IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>();
  IdArray dC_columns;
  NDArray dC_weights;
  IdType* dC_columns_data = dC_columns.Ptr<IdType>();
  DType* dC_weights_data = dC_weights.Ptr<DType>();
  /* prepare buffer */
  CUSPARSE_CALL(CSRGEAM<DType>::bufferSizeExt(
      thr_entry->cusparse_handle, m, n, &alpha, matA, nnzA, A_weights,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), &beta, matB, nnzB,
      B_weights, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), matC,
      dC_weights_data, dC_csrOffsets_data, dC_columns_data, &workspace_size));
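  // workspace_size now holds the scratch size shared by the nnz and compute
  // phases below.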

  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  CUSPARSE_CALL(CSRGEAM<DType>::nnz(
      thr_entry->cusparse_handle, m, n, matA, nnzA, A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(), matB, nnzB, B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(), matC, dC_csrOffsets_data, &nnzC, workspace));
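  // nnzC now holds the number of nonzeros of A + B, so C's column index and
  // value arrays can be allocated.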

  dC_columns = IdArray::Empty({nnzC}, A.indptr->dtype, ctx);
  dC_weights = NDArray::Empty({nnzC}, A_weights_array->dtype, ctx);
  dC_columns_data = dC_columns.Ptr<IdType>();
  dC_weights_data = dC_weights.Ptr<DType>();

  CUSPARSE_CALL(CSRGEAM<DType>::compute(
      thr_entry->cusparse_handle, m, n, &alpha, matA, nnzA, A_weights,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), &beta, matB, nnzB,
      B_weights, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), matC,
      dC_weights_data, dC_csrOffsets_data, dC_columns_data, workspace));

  device->FreeWorkspace(ctx, workspace);
  // destroy matrix/vector descriptors
  CUSPARSE_CALL(cusparseDestroyMatDescr(matA));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matB));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matC));
  return {
      CSRMatrix(
          A.num_rows, A.num_cols, dC_csrOffsets, dC_columns,
          NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx), true),
      dC_weights};
}
}  // namespace cusparse

template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRSum(
    const std::vector<CSRMatrix>& As, const std::vector<NDArray>& A_weights) {
  const int64_t M = As[0].num_rows;
  const int64_t N = As[0].num_cols;
  const int64_t n = As.size();

  // Cast 64-bit indices to 32-bit, since cusparse<t>csrgeam2 only accepts
  // int (32-bit) row offsets and column indices.
  std::vector<CSRMatrix> newAs;
  newAs.reserve(n);
  bool cast = false;
  if (As[0].indptr->dtype.bits == 64) {
    for (int i = 0; i < n; ++i)
      newAs.emplace_back(
          As[i].num_rows, As[i].num_cols, AsNumBits(As[i].indptr, 32),
          AsNumBits(As[i].indices, 32), AsNumBits(As[i].data, 32));
    cast = true;
  } else {
    for (int i = 0; i < n; ++i) newAs.push_back(As[i]);
  }

  // cuSPARSE csrgeam2 requires the CSR to be sorted.
  // TODO(BarclayII): ideally the sorted CSR should be cached but I'm not sure
  // how to do it.
  for (int i = 0; i < n; ++i) {
    if (!newAs[i].sorted) newAs[i] = CSRSort(newAs[i]);
  }

  // Reorder weights if A[i] has edge IDs
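  // newAs[i].data, when present, maps each (possibly sorted) nonzero back to
  // its original edge ID, so the weights are gathered through that mapping
  // before being handed to cuSPARSE.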
  std::vector<NDArray> A_weights_reordered(n);
  for (int i = 0; i < n; ++i) {
    if (CSRHasData(newAs[i]))
      A_weights_reordered[i] = IndexSelect(A_weights[i], newAs[i].data);
    else
      A_weights_reordered[i] = A_weights[i];
  }

  // Loop and sum
  auto result = std::make_pair(
      CSRMatrix(
          newAs[0].num_rows, newAs[0].num_cols, newAs[0].indptr,
          newAs[0].indices,
          NullArray(newAs[0].indptr->dtype, newAs[0].indptr->ctx)),
      A_weights_reordered[0]);  // Weights already reordered so we don't need
                                // As[0].data
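  // Accumulate the remaining matrices pairwise: result = result + newAs[i].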
  for (int64_t i = 1; i < n; ++i)
    result = cusparse::CusparseCsrgeam2<DType, int32_t>(
        result.first, result.second, newAs[i], A_weights_reordered[i]);

  // Cast 32 bit indices back to 64 bit if necessary
  if (cast) {
    CSRMatrix C = result.first;
    return {
        CSRMatrix(
            C.num_rows, C.num_cols, AsNumBits(C.indptr, 64),
            AsNumBits(C.indices, 64), AsNumBits(C.data, 64), true),
        result.second};
  } else {
    return result;
  }
}

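// The explicit instantiations below cover half, bfloat16, float, and double
// values with 32- and 64-bit indices. An illustrative (not authoritative)
// usage sketch with the float/int64 instantiation, assuming A, B, A_weights,
// and B_weights are existing CUDA CSR matrices and weight arrays:
//
//   std::vector<CSRMatrix> mats = {A, B};
//   std::vector<NDArray> weights = {A_weights, B_weights};
//   std::pair<CSRMatrix, NDArray> C =
//       CSRSum<kDGLCUDA, int64_t, float>(mats, weights);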
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __half>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __half>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
#if BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
#endif  // BF16_ENABLED
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, float>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, float>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, double>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, double>(
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);

}  // namespace aten
}  // namespace dgl