#include <torch/extension.h>
// Asserts that tensor `x` lives on a CUDA device. On failure, AT_ASSERTM
// reports "<argument expression> must be CUDA tensor", where the argument
// text comes from stringifying `x`. Uses Tensor::device() instead of the
// deprecated Tensor::type() accessor.
#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM((x).device().is_cuda(), #x " must be CUDA tensor")
// Forward declarations of the CUDA implementations. They are not defined in
// this file (presumably in the companion .cu translation unit; confirm).

// CUDA sparse-sparse matrix product returning a pair of tensors.
// NOTE(review): the pair is presumably the (index, value) of the result, and
// m, k, n presumably give A as m x k and B as k x n. Confirm against the
// implementation.
std::tuple<at::Tensor, at::Tensor>
spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
            at::Tensor valueB, size_t m, size_t k, size_t n);

// CUDA backward helper for spspmm.
// NOTE(review): the roles of `index`, `rowA_max` and `rowB_max` are not
// visible from this file. Confirm against the implementation.
at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
                          at::Tensor valueA, at::Tensor indexB,
                          at::Tensor valueB, size_t rowA_max, size_t rowB_max);
std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor indexA, at::Tensor valueA,
                                          at::Tensor indexB, at::Tensor valueB,
rusty1s's avatar
rusty1s committed
14
                                          size_t m, size_t k, size_t n) {
rusty1s's avatar
rusty1s committed
15
16
17
18
19
  CHECK_CUDA(indexA);
  CHECK_CUDA(valueA);
  CHECK_CUDA(indexB);
  CHECK_CUDA(valueB);
  return spspmm_cuda(indexA, valueA, indexB, valueB, m, k, n);
rusty1s's avatar
rusty1s committed
20
21
}

// Python-facing backward entry point for spspmm.
//
// Verifies that index, indexA, valueA, indexB and valueB are all CUDA tensors,
// checked in that order so the first failing argument is the one reported.
// It then delegates to spspmm_bw_cuda with every argument passed through
// unchanged.
at::Tensor spspmm_bw(at::Tensor index,
                     at::Tensor indexA,
                     at::Tensor valueA,
                     at::Tensor indexB,
                     at::Tensor valueB,
                     size_t rowA_max,
                     size_t rowB_max) {
  // Device guards: every input must already live on a CUDA device.
  CHECK_CUDA(index);
  CHECK_CUDA(indexA);
  CHECK_CUDA(valueA);
  CHECK_CUDA(indexB);
  CHECK_CUDA(valueB);

  at::Tensor result = spspmm_bw_cuda(index, indexA, valueA, indexB, valueB,
                                     rowA_max, rowB_max);
  return result;
}

// Python bindings: registers spspmm and spspmm_bw on the extension module
// whose name is supplied at build time via TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("spspmm", &spspmm, "Sparse-Sparse Matrix Multiplication (CUDA)");
  m.def("spspmm_bw", &spspmm_bw,
        "Sparse-Sparse Matrix Multiplication Backward (CUDA)");
}