spspmm_kernel.cu 6.58 KB
Newer Older
rusty1s's avatar
to csr  
rusty1s committed
1
2
3
#include <ATen/ATen.h>
#include <cusparse.h>

rusty1s's avatar
rusty1s committed
4
5
#include "compat.cuh"

rusty1s's avatar
rusty1s committed
6
7
8
// Threads per block used for kernel launches.
#define THREADS 1024
// Number of blocks of THREADS threads needed to cover N elements (ceiling
// division). The argument and the whole expansion are parenthesized so the
// macro is safe inside larger expressions and with compound arguments.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

rusty1s's avatar
rusty1s committed
9
10
// Dispatches to the cuSPARSE csrgemm routine that matches the scalar type
// TYPE (Float -> cusparseScsrgemm, Double -> cusparseDcsrgemm) and forwards
// __VA_ARGS__ to it. `scalar_t` is defined inside each case so the forwarded
// arguments may refer to it. The expression evaluates to the status returned
// by the selected routine; any other scalar type raises via AT_ERROR.
#define CSRGEMM(TYPE, ...)                                                     \
  [&] {                                                                        \
    const auto &the_type = TYPE;                                               \
    /* Do not use TYPE again: it may be expensive or have side effects. */     \
    at::ScalarType _st = ::detail::scalar_type(the_type);                      \
    switch (_st) {                                                             \
    case at::ScalarType::Float: {                                              \
      using scalar_t = float;                                                  \
      return cusparseScsrgemm(__VA_ARGS__);                                    \
    }                                                                          \
    case at::ScalarType::Double: {                                             \
      using scalar_t = double;                                                 \
      return cusparseDcsrgemm(__VA_ARGS__);                                    \
    }                                                                          \
    default:                                                                   \
      AT_ERROR("Not implemented for '", toString(_st), "'");                   \
    }                                                                          \
  }()

rusty1s's avatar
to csr  
rusty1s committed
28
29
30
31
32
33
34
35
// File-local cuSPARSE handle, created lazily by init_cusparse().
static cusparseHandle_t cusparse_handle = 0;

// Creates the shared cuSPARSE handle on first use; later calls are no-ops.
// Raises via AT_ERROR when cusparseCreate does not report success, instead of
// silently leaving an unusable handle for the cuSPARSE calls that follow.
static void init_cusparse() {
  if (cusparse_handle != 0) {
    return;
  }

  const cusparseStatus_t status = cusparseCreate(&cusparse_handle);
  if (status != CUSPARSE_STATUS_SUCCESS) {
    // Reset so that a later call retries instead of trusting a bad handle.
    cusparse_handle = 0;
    AT_ERROR("cusparseCreate failed with status ", static_cast<int>(status));
  }
}

rusty1s's avatar
rusty1s committed
36
37
// Sparse-sparse matrix product C = A * B computed with cuSPARSE csrgemm.
//
// indexA, valueA: COO indices (indexA[0] = rows, indexA[1] = columns) and
//                 values of A, which is treated as an m x k matrix.
// indexB, valueB: COO indices and values of B, treated as a k x n matrix.
// Returns (indexC, valueC): indexC stacks the row and column indices of C
// along dim 0 and is converted to int64; valueC holds the matching values.
//
// Indices are converted to 32-bit ints for cuSPARSE.
// NOTE(review): the COO -> CSR conversion presumably requires row-sorted
// (coalesced) input; nothing here checks that -- confirm against callers.
std::tuple<at::Tensor, at::Tensor>
spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
            at::Tensor valueB, size_t m, size_t k, size_t n) {
  cudaSetDevice(indexA.get_device());
  init_cusparse();

  indexA = indexA.contiguous();
  valueA = valueA.contiguous();
  indexB = indexB.contiguous();
  valueB = valueB.contiguous();

  // cuSPARSE takes 32-bit sizes; convert once instead of at every call site.
  const int rows = static_cast<int>(m);
  const int inner = static_cast<int>(k);
  const int cols = static_cast<int>(n);
  const int nnzA = static_cast<int>(valueA.size(0));
  const int nnzB = static_cast<int>(valueB.size(0));

  indexA = indexA.toType(at::kInt);
  indexB = indexB.toType(at::kInt);

  // Convert A to CSR format. A has m rows, so the row count handed to
  // coo2csr must be m (row_ptrA holds m + 1 entries).
  auto row_ptrA = at::empty(m + 1, indexA.options());
  cusparseXcoo2csr(cusparse_handle, indexA[0].DATA_PTR<int>(), nnzA, rows,
                   row_ptrA.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
  auto colA = indexA[1];
  // Kept from the original: pins the final row pointer to nnz. Presumably
  // redundant now that coo2csr receives the right row count -- TODO confirm.
  cudaMemcpy(row_ptrA.DATA_PTR<int>() + m, &nnzA, sizeof(int),
             cudaMemcpyHostToDevice);

  // Convert B to CSR format (B has k rows).
  auto row_ptrB = at::empty(k + 1, indexB.options());
  cusparseXcoo2csr(cusparse_handle, indexB[0].DATA_PTR<int>(), nnzB, inner,
                   row_ptrB.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
  auto colB = indexB[1];
  cudaMemcpy(row_ptrB.DATA_PTR<int>() + k, &nnzB, sizeof(int),
             cudaMemcpyHostToDevice);

  cusparseMatDescr_t descr = 0;
  cusparseCreateMatDescr(&descr);
  cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL);
  cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);

  // First pass: row pointers of C and its number of non-zeros.
  int nnzC = 0;
  auto row_ptrC = at::empty(m + 1, indexB.options());
  const cusparseStatus_t nnz_status = cusparseXcsrgemmNnz(
      cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
      CUSPARSE_OPERATION_NON_TRANSPOSE, rows, cols, inner, descr, nnzA,
      row_ptrA.DATA_PTR<int>(), colA.DATA_PTR<int>(), descr, nnzB,
      row_ptrB.DATA_PTR<int>(), colB.DATA_PTR<int>(), descr,
      row_ptrC.DATA_PTR<int>(), &nnzC);
  if (nnz_status != CUSPARSE_STATUS_SUCCESS) {
    // nnzC is not trustworthy here; do not size the outputs from it.
    cusparseDestroyMatDescr(descr);
    AT_ERROR("cusparseXcsrgemmNnz failed with status ",
             static_cast<int>(nnz_status));
  }

  auto colC = at::empty(nnzC, indexA.options());
  auto valueC = at::empty(nnzC, valueA.options());

  // Second pass: column indices and values of C.
  const cusparseStatus_t gemm_status = CSRGEMM(
      valueC.scalar_type(), cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
      CUSPARSE_OPERATION_NON_TRANSPOSE, rows, cols, inner, descr, nnzA,
      valueA.DATA_PTR<scalar_t>(), row_ptrA.DATA_PTR<int>(),
      colA.DATA_PTR<int>(), descr, nnzB, valueB.DATA_PTR<scalar_t>(),
      row_ptrB.DATA_PTR<int>(), colB.DATA_PTR<int>(), descr,
      valueC.DATA_PTR<scalar_t>(), row_ptrC.DATA_PTR<int>(),
      colC.DATA_PTR<int>());

  // The descriptor is not needed past this point; release it so repeated
  // calls do not leak one descriptor each.
  cusparseDestroyMatDescr(descr);

  if (gemm_status != CUSPARSE_STATUS_SUCCESS) {
    AT_ERROR("cusparse csrgemm failed with status ",
             static_cast<int>(gemm_status));
  }

  // Expand the CSR row pointers of C back into per-entry row indices.
  auto rowC = at::empty(nnzC, indexA.options());
  cusparseXcsr2coo(cusparse_handle, row_ptrC.DATA_PTR<int>(), nnzC, rows,
                   rowC.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);

  auto indexC = at::stack({rowC, colC}, 0).toType(at::kLong);

  return std::make_tuple(indexC, valueC);
}
rusty1s's avatar
rusty1s committed
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156

// Returns a tensor of length num_nodes whose entry v is the number of
// elements of `row` equal to v. The result uses the same options as `row`.
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
  const auto opts = row.options();
  auto counts = at::zeros(num_nodes, opts);
  const auto ones = at::ones(row.size(0), opts);
  // Add 1 at position row[e] for every entry e.
  counts.scatter_add_(0, row, ones);
  return counts;
}

// Turns COO row indices into a row-pointer array of length num_nodes + 1:
// a leading zero followed by the running sum of the per-row entry counts.
// `col` is passed through unchanged. The input is expected to be coalesced
// already; this is not verified here.
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
                                          int64_t num_nodes) {
  const auto running_total = degree(row, num_nodes).cumsum(0);
  const auto leading_zero = at::zeros(1, running_total.options());
  auto row_ptr = at::cat({leading_zero, running_total}, 0);
  return std::make_tuple(row_ptr, col);
}

// For every output position e, with i = index[e] and j = index[numel + e],
// adds to value[e] the sum of valueA[dA] * valueB[dB] over all pairs of
// entries in row i of A and row j of B (both given by row pointers and
// column ids) that share the same column id.
//
// Launch layout: 1-D grid of 1-D blocks. Each thread starts at its global
// thread id and advances by the total number of launched threads, so any
// grid size covers all numel positions.
// Preconditions: `index` holds 2 * numel entries (all i first, then all j).
// The early exit of the inner loop relies on the column ids within a row of
// B being in ascending order.
template <typename scalar_t>
__global__ void spspmm_bw_kernel(
    const int64_t *__restrict__ index, scalar_t *__restrict__ value,
    const int64_t *__restrict__ rowA, const int64_t *__restrict__ colA,
    const scalar_t *__restrict__ valueA, const int64_t *__restrict__ rowB,
    const int64_t *__restrict__ colB, const scalar_t *__restrict__ valueB,
    const size_t numel) {
  // Widen before multiplying: the built-in index variables are 32-bit and
  // their product can wrap for very large launches.
  const size_t first =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  for (size_t e = first; e < numel; e += stride) {
    const int64_t i = index[e], j = index[numel + e];

    // Row extents do not change inside the loops below; read them once.
    const int64_t endA = rowA[i + 1];
    const int64_t beginB = rowB[j], endB = rowB[j + 1];

    // Accumulate in a register and write the result back once.
    scalar_t acc = value[e];

    for (int64_t dA = rowA[i]; dA < endA; dA++) {
      const int64_t cA = colA[dA];

      for (int64_t dB = beginB; dB < endB; dB++) {
        const int64_t cB = colB[dB];

        if (cA == cB) {
          acc += valueA[dA] * valueB[dB];
        }

        // Past cA in B's row: no further match for this entry of A.
        if (cB >= cA) {
          break;
        }
      }
    }

    value[e] = acc;
  }
}

// Host wrapper around spspmm_bw_kernel.
//
// index:            2 x E tensor of (i, j) positions to evaluate.
// indexA, valueA:   COO indices / values of A; rowA_max is the number of rows
//                   used to build A's row pointers.
// indexB, valueB:   COO indices / values of B; rowB_max likewise for B.
// Returns a tensor with E entries (same options as valueA) filled by the
// kernel. Raises via AT_ERROR if the kernel launch reports an error.
at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
                          at::Tensor valueA, at::Tensor indexB,
                          at::Tensor valueB, size_t rowA_max, size_t rowB_max) {
  cudaSetDevice(index.get_device());
  auto value = at::zeros(index.size(1), valueA.options());

  // Nothing to compute. This also avoids launching a grid with zero blocks,
  // which is an invalid launch configuration.
  if (value.numel() == 0) {
    return value;
  }

  // The kernel indexes raw pointers and assumes densely packed storage.
  index = index.contiguous();
  indexA = indexA.contiguous();
  valueA = valueA.contiguous();
  indexB = indexB.contiguous();
  valueB = valueB.contiguous();

  at::Tensor rowA, colA;
  std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);

  at::Tensor rowB, colB;
  std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);

  AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
    spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
        index.DATA_PTR<int64_t>(), value.DATA_PTR<scalar_t>(),
        rowA.DATA_PTR<int64_t>(), colA.DATA_PTR<int64_t>(),
        valueA.DATA_PTR<scalar_t>(), rowB.DATA_PTR<int64_t>(),
        colB.DATA_PTR<int64_t>(), valueB.DATA_PTR<scalar_t>(), value.numel());
  });

  // Kernel launches do not return a status; query it explicitly so a bad
  // launch is reported here rather than at some unrelated later call.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess) {
    AT_ERROR("spspmm_bw_kernel launch failed: ",
             cudaGetErrorString(launch_status));
  }

  return value;
}