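// spspmm_kernel.cu
//
// Sparse-sparse matrix multiplication (spspmm) on the GPU. The forward pass
// converts both COO operands to CSR and delegates to cuSPARSE csrgemm; the
// backward helper evaluates A * B^T restricted to a given sparsity pattern
// with a custom kernel.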
#include <ATen/ATen.h>

#include <cusparse.h>

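// Launch configuration: 1024 threads per block, with enough blocks to cover
// N work items (kernels below use a grid-stride loop).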
#define THREADS 1024
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

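// Dispatches to the cuSPARSE csrgemm routine that matches the value type:
// float -> cusparseScsrgemm, double -> cusparseDcsrgemm.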
#define CSRGEMM(TYPE, ...)                                                     \
  [&] {                                                                        \
    const at::Type &the_type = TYPE;                                           \
    switch (the_type.scalarType()) {                                           \
    case at::ScalarType::Float: {                                              \
      using scalar_t = float;                                                  \
      return cusparseScsrgemm(__VA_ARGS__);                                    \
    }                                                                          \
    case at::ScalarType::Double: {                                             \
      using scalar_t = double;                                                 \
      return cusparseDcsrgemm(__VA_ARGS__);                                    \
    }                                                                          \
    default:                                                                   \
      AT_ERROR("Not implemented for '%s'", the_type.toString());               \
    }                                                                          \
  }()

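// Lazily initialized, process-wide cuSPARSE handle.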
static cusparseHandle_t cusparse_handle = 0;

static void init_cusparse() {
  if (cusparse_handle == 0) {
    cusparseStatus_t status = cusparseCreate(&cusparse_handle);
    if (status != CUSPARSE_STATUS_SUCCESS) {
      AT_ERROR("cuSPARSE handle creation failed");
    }
  }
}

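// Computes C = A * B for sparse COO matrices A (m x k) and B (k x n).
// Returns the COO indices (2 x nnzC, int64) and values of C.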
std::tuple<at::Tensor, at::Tensor>
spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
            at::Tensor valueB, size_t m, size_t k, size_t n) {
  cudaSetDevice(indexA.get_device());
  init_cusparse();

  indexA = indexA.contiguous();
  valueA = valueA.contiguous();
  indexB = indexB.contiguous();
  valueB = valueB.contiguous();

  // cuSPARSE expects 32-bit nnz counts and indices.
  int nnzA = valueA.size(0);
  int nnzB = valueB.size(0);

  indexA = indexA.toType(at::kInt);
  indexB = indexB.toType(at::kInt);

  // Convert A to CSR format.
  auto row_ptrA = at::empty(m + 1, indexA.type());
  cusparseXcoo2csr(cusparse_handle, indexA[0].data<int>(), nnzA, m,
                   row_ptrA.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
  auto colA = indexA[1];
  // Explicitly write the final row pointer entry (= nnzA).
  cudaMemcpy(row_ptrA.data<int>() + m, &nnzA, sizeof(int),
             cudaMemcpyHostToDevice);

  // Convert B to CSR format.
  auto row_ptrB = at::empty(k + 1, indexB.type());
  cusparseXcoo2csr(cusparse_handle, indexB[0].data<int>(), nnzB, k,
                   row_ptrB.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
  auto colB = indexB[1];
  // Explicitly write the final row pointer entry (= nnzB).
  cudaMemcpy(row_ptrB.data<int>() + k, &nnzB, sizeof(int),
             cudaMemcpyHostToDevice);

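  // Generic matrix descriptor with zero-based indexing, shared by A, B and C.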
  cusparseMatDescr_t descr = 0;
  cusparseCreateMatDescr(&descr);
  cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL);
  cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);

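  // First csrgemm pass: compute nnz and the CSR row pointer of C.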
  int nnzC;
  auto row_ptrC = at::empty(m + 1, indexB.type());
  cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                      CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
                      row_ptrA.data<int>(), colA.data<int>(), descr, nnzB,
                      row_ptrB.data<int>(), colB.data<int>(), descr,
                      row_ptrC.data<int>(), &nnzC);
  auto colC = at::empty(nnzC, indexA.type());
  auto valueC = at::empty(nnzC, valueA.type());

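  // Second csrgemm pass: compute the column indices and values of C.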
  CSRGEMM(valueC.type(), cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
          CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
          valueA.data<scalar_t>(), row_ptrA.data<int>(), colA.data<int>(),
          descr, nnzB, valueB.data<scalar_t>(), row_ptrB.data<int>(),
          colB.data<int>(), descr, valueC.data<scalar_t>(),
          row_ptrC.data<int>(), colC.data<int>());

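  // Convert the CSR row pointer of C back to COO row indices.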
  auto rowC = at::empty(nnzC, indexA.type());
  cusparseXcsr2coo(cusparse_handle, row_ptrC.data<int>(), nnzC, m,
                   rowC.data<int>(), CUSPARSE_INDEX_BASE_ZERO);

  auto indexC = at::stack({rowC, colC}, 0).toType(at::kLong);

  return std::make_tuple(indexC, valueC);
}

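// Computes the number of stored entries per row by scattering ones into a
// zero vector of length num_nodes.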
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
  auto zero = at::zeros(num_nodes, row.options());
  auto one = at::ones(row.size(0), row.options());
  return zero.scatter_add_(0, row, one);
}

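// Builds a CSR row pointer from sorted COO row indices; the column indices
// are passed through unchanged.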
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
                                          int64_t num_nodes) {
  // Assumes the input is already coalesced (row indices sorted).
  row = degree(row, num_nodes).cumsum(0);
  row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
  return std::make_tuple(row, col);
}

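// For every coordinate (i, j) stored in `index`, accumulates the dot product
// of row i of A with row j of B, i.e. (A * B^T)[i, j]. The column indices of
// B are assumed to be sorted within each row (coalesced input), which lets
// the inner loop stop early once colB passes colA.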
template <typename scalar_t>
__global__ void spspmm_bw_kernel(
    const int64_t *__restrict__ index, scalar_t *__restrict__ value,
    const int64_t *__restrict__ rowA, const int64_t *__restrict__ colA,
    const scalar_t *__restrict__ valueA, const int64_t *__restrict__ rowB,
    const int64_t *__restrict__ colB, const scalar_t *__restrict__ valueB,
    const size_t numel) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  // Grid-stride loop over the output entries.
  for (ptrdiff_t e = idx; e < numel; e += stride) {
    // Entry e has row index[e] and column index[numel + e].
    int64_t i = index[e], j = index[numel + e];

    for (ptrdiff_t dA = rowA[i]; dA < rowA[i + 1]; dA++) {
      int64_t cA = colA[dA];

      for (ptrdiff_t dB = rowB[j]; dB < rowB[j + 1]; dB++) {
        int64_t cB = colB[dB];

        if (cA == cB) {
          value[e] += valueA[dA] * valueB[dB];
        }

        if (cB >= cA) {
          break;
        }
      }
    }
  }
}

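// Host wrapper for the backward kernel: converts both operands to CSR row
// pointers and evaluates A * B^T at the coordinates given by `index`.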
at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
                          at::Tensor valueA, at::Tensor indexB,
                          at::Tensor valueB, size_t rowA_max, size_t rowB_max) {
  cudaSetDevice(index.get_device());
  auto value = at::zeros(index.size(1), valueA.options());

  at::Tensor rowA, colA;
  std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);

  at::Tensor rowB, colB;
  std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);

  AT_DISPATCH_FLOATING_TYPES(valueA.type(), "spspmm_bw", [&] {
    spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
        index.data<int64_t>(), value.data<scalar_t>(), rowA.data<int64_t>(),
        colA.data<int64_t>(), valueA.data<scalar_t>(), rowB.data<int64_t>(),
        colB.data<int64_t>(), valueB.data<scalar_t>(), value.numel());
  });

  return value;
}
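
// A minimal sketch (not part of this file) of how these entry points might be
// exposed to Python from a separate pybind11 extension source. The module
// setup below is an assumption for illustration, not the library's actual
// binding code:
//
//   #include <torch/extension.h>
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("spspmm_cuda", &spspmm_cuda, "Sparse-sparse matmul forward (CUDA)");
//     m.def("spspmm_bw_cuda", &spspmm_bw_cuda,
//           "Sparse-sparse matmul backward (CUDA)");
//   }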