// matmul.cpp
// (Source recovered from a repository blame view; author: rusty1s.)
#include <torch/torch.h>

// Older PyTorch releases spell the user-facing input check AT_CHECK rather
// than TORCH_CHECK; fall back to it so this file builds against either API.
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

// Validates that tensor `x` lives on a CUDA device. If it does not, raises a
// user-facing error that names the offending argument.
// TORCH_CHECK (not AT_ASSERT) is the macro intended for validating caller
// input; AT_ASSERT is reserved for internal invariants. `x` is parenthesized
// so that arbitrary argument expressions expand safely.
#define CHECK_CUDA(x) \
  TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")

// Commit note (rusty1s): "to csr"
// Device implementation. It is defined elsewhere (presumably the accompanying
// CUDA source file — confirm against the build).
at::Tensor spspmm_cuda(at::Tensor matrix1, at::Tensor matrix2);

/// @brief Sparse-sparse matrix multiplication entry point exposed to Python.
///
/// Checks that both operands are CUDA tensors, then forwards them unchanged
/// to spspmm_cuda. The sparse layout spspmm_cuda expects is not visible here.
/// NOTE(review): confirm the expected format (e.g. COO vs. CSR) against the
/// CUDA implementation.
///
/// @param matrix1 Left-hand operand; must be a CUDA tensor.
/// @param matrix2 Right-hand operand; must be a CUDA tensor.
/// @return The result produced by spspmm_cuda(matrix1, matrix2).
/// @throws The error raised by CHECK_CUDA if either operand is not a CUDA
///         tensor.
at::Tensor spspmm(at::Tensor matrix1, at::Tensor matrix2) {
  CHECK_CUDA(matrix1);
  CHECK_CUDA(matrix2);
  return spspmm_cuda(matrix1, matrix2);
}

// Python bindings for this extension. TORCH_EXTENSION_NAME is expected to be
// defined by the build (e.g. by torch.utils.cpp_extension) — it is not defined
// in this file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  // Exposes the C++ spspmm wrapper as `spspmm` on the Python module.
  auto *const entry = &spspmm;
  mod.def("spspmm", entry,
          "Sparse-Sparse Matrix Multiplication (CUDA)");
}