// matmul.cpp — PyTorch C++ extension exposing sparse-sparse matrix
// multiplication to Python (original author: rusty1s).
#include <torch/torch.h>

// Validates that tensor argument `x` lives on a CUDA device, failing via
// AT_ASSERT otherwise. `#x` stringifies the argument expression so the error
// message names the offending parameter.
// NOTE(review): Tensor::type() is deprecated in newer PyTorch releases, and
// AT_ASSERT is intended for internal invariants rather than user-input
// validation. On PyTorch versions that provide it, TORCH_CHECK(x.is_cuda(), ...)
// is the usual replacement — confirm the targeted PyTorch version first.
#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")

// CUDA implementation of sparse-sparse matrix multiplication; defined in a
// separate translation unit (not part of this file).
// NOTE(review): the expected layout of A and B and the meaning of the two
// returned tensors (presumably index and value tensors of the product) are
// fixed by that implementation — confirm against the .cu source.
std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B);
// Python-facing entry point for sparse-sparse matrix multiplication.
//
// Checks that both operands are CUDA tensors (see CHECK_CUDA), then forwards
// them unchanged to spspmm_cuda and returns its result as-is.
//
// NOTE(review): the required format of A and B (e.g. sparse layout, index
// dtype) and the meaning of the two returned tensors are defined by
// spspmm_cuda, which is not visible here — confirm against that code.
std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor A, at::Tensor B) {
  CHECK_CUDA(A);
  CHECK_CUDA(B);
  return spspmm_cuda(A, B);
}

// Python module definition: registers spspmm under the Python name "spspmm".
// TORCH_EXTENSION_NAME is supplied by the PyTorch extension build tooling.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
  ext.def("spspmm",
          &spspmm,
          "Sparse-Sparse Matrix Multiplication (CUDA)");
}