// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/spmm.cu
 * @brief SpGEAM C APIs and definitions.
 */
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>
#include "../../runtime/cuda/cuda_common.h"
#include "cusparse_dispatcher.cuh"
#include "functor.cuh"
namespace dgl {

using namespace dgl::runtime;

namespace aten {
namespace cusparse {

22
/** Cusparse implementation of SpSum on Csr format. */
23
24
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
25
    const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
26
27
28
29
30
31
32
33
34
35
36
    const NDArray B_weights_array) {
  const int m = A.num_rows;
  const int n = A.num_cols;
  const int nnzA = A.indices->shape[0];
  const int nnzB = B.indices->shape[0];
  int nnzC;
  const DType alpha = 1.0;
  const DType beta = 1.0;
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
sangwzh's avatar
sangwzh committed
37
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
38
39
40
41
  const DType* A_weights = A_weights_array.Ptr<DType>();
  const DType* B_weights = B_weights_array.Ptr<DType>();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle)
sangwzh's avatar
sangwzh committed
42
43
    CUSPARSE_CALL(hipsparseCreate(&(thr_entry->cusparse_handle)));
  CUSPARSE_CALL(hipsparseSetStream(thr_entry->cusparse_handle, stream));
44

sangwzh's avatar
sangwzh committed
45
46
47
48
  hipsparseMatDescr_t matA, matB, matC;
  CUSPARSE_CALL(hipsparseCreateMatDescr(&matA));
  CUSPARSE_CALL(hipsparseCreateMatDescr(&matB));
  CUSPARSE_CALL(hipsparseCreateMatDescr(&matC));
49

sangwzh's avatar
sangwzh committed
50
51
  hipsparseSetPointerMode(
      thr_entry->cusparse_handle, HIPSPARSE_POINTER_MODE_HOST);
52
53
  size_t workspace_size = 0;
  /* prepare output C */
54
  IdArray dC_csrOffsets = IdArray::Empty({m + 1}, A.indptr->dtype, ctx);
55
56
57
58
59
60
61
  IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>();
  IdArray dC_columns;
  NDArray dC_weights;
  IdType* dC_columns_data = dC_columns.Ptr<IdType>();
  DType* dC_weights_data = dC_weights.Ptr<DType>();
  /* prepare buffer */
  CUSPARSE_CALL(CSRGEAM<DType>::bufferSizeExt(
62
63
64
65
66
67
68
69
70
71
      thr_entry->cusparse_handle, m, n, &alpha, matA, nnzA, A_weights,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), &beta, matB, nnzB,
      B_weights, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), matC,
      dC_weights_data, dC_csrOffsets_data, dC_columns_data, &workspace_size));

  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  CUSPARSE_CALL(CSRGEAM<DType>::nnz(
      thr_entry->cusparse_handle, m, n, matA, nnzA, A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(), matB, nnzB, B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(), matC, dC_csrOffsets_data, &nnzC, workspace));
72
73
74
75
76
77
78

  dC_columns = IdArray::Empty({nnzC}, A.indptr->dtype, ctx);
  dC_weights = NDArray::Empty({nnzC}, A_weights_array->dtype, ctx);
  dC_columns_data = dC_columns.Ptr<IdType>();
  dC_weights_data = dC_weights.Ptr<DType>();

  CUSPARSE_CALL(CSRGEAM<DType>::compute(
79
80
81
82
      thr_entry->cusparse_handle, m, n, &alpha, matA, nnzA, A_weights,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), &beta, matB, nnzB,
      B_weights, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), matC,
      dC_weights_data, dC_csrOffsets_data, dC_columns_data, workspace));
83
84
85

  device->FreeWorkspace(ctx, workspace);
  // destroy matrix/vector descriptors
sangwzh's avatar
sangwzh committed
86
87
88
  CUSPARSE_CALL(hipsparseDestroyMatDescr(matA));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(matB));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(matC));
89
  return {
90
91
92
93
      CSRMatrix(
          A.num_rows, A.num_cols, dC_csrOffsets, dC_columns,
          NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx), true),
      dC_weights};
94
95
96
97
98
}
}  // namespace cusparse

template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRSum(
99
    const std::vector<CSRMatrix>& As, const std::vector<NDArray>& A_weights) {
100
101
102
103
104
105
  const int64_t M = As[0].num_rows;
  const int64_t N = As[0].num_cols;
  const int64_t n = As.size();

  // Cast 64 bit indices to 32 bit
  std::vector<CSRMatrix> newAs;
106
  newAs.reserve(n);
107
108
109
110
  bool cast = false;
  if (As[0].indptr->dtype.bits == 64) {
    for (int i = 0; i < n; ++i)
      newAs.emplace_back(
111
112
          As[i].num_rows, As[i].num_cols, AsNumBits(As[i].indptr, 32),
          AsNumBits(As[i].indices, 32), AsNumBits(As[i].data, 32));
113
    cast = true;
114
  } else {
115
    for (int i = 0; i < n; ++i) newAs.push_back(As[i]);
116
117
118
  }

  // cuSPARSE csrgeam2 requires the CSR to be sorted.
119
120
  // TODO(BarclayII): ideally the sorted CSR should be cached but I'm not sure
  // how to do it.
121
  for (int i = 0; i < n; ++i) {
122
    if (!newAs[i].sorted) newAs[i] = CSRSort(newAs[i]);
123
124
125
126
127
  }

  // Reorder weights if A[i] has edge IDs
  std::vector<NDArray> A_weights_reordered(n);
  for (int i = 0; i < n; ++i) {
128
129
    if (CSRHasData(newAs[i]))
      A_weights_reordered[i] = IndexSelect(A_weights[i], newAs[i].data);
130
131
132
133
134
135
136
    else
      A_weights_reordered[i] = A_weights[i];
  }

  // Loop and sum
  auto result = std::make_pair(
      CSRMatrix(
137
138
139
140
141
          newAs[0].num_rows, newAs[0].num_cols, newAs[0].indptr,
          newAs[0].indices,
          NullArray(newAs[0].indptr->dtype, newAs[0].indptr->ctx)),
      A_weights_reordered[0]);  // Weights already reordered so we don't need
                                // As[0].data
142
143
  for (int64_t i = 1; i < n; ++i)
    result = cusparse::CusparseCsrgeam2<DType, int32_t>(
144
        result.first, result.second, newAs[i], A_weights_reordered[i]);
145
146
147
148
149

  // Cast 32 bit indices back to 64 bit if necessary
  if (cast) {
    CSRMatrix C = result.first;
    return {
150
151
152
153
        CSRMatrix(
            C.num_rows, C.num_cols, AsNumBits(C.indptr, 64),
            AsNumBits(C.indices, 64), AsNumBits(C.data, 64), true),
        result.second};
154
155
156
157
158
  } else {
    return result;
  }
}

159
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __half>(
160
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
161
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __half>(
162
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
163
#if BF16_ENABLED
sangwzh's avatar
sangwzh committed
164
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __hip_bfloat16>(
165
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
sangwzh's avatar
sangwzh committed
166
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __hip_bfloat16>(
167
168
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
#endif  // BF16_ENABLED
169
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, float>(
170
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
171
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, float>(
172
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
173
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, double>(
174
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
175
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, double>(
176
177
178
179
    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);

}  // namespace aten
}  // namespace dgl