spmm.cpp 1.46 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
#include <utility>

#include <torch/script.h>
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
4
// Fails with "<expr> must be CUDA tensor" unless tensor expression `x` lives
// on a CUDA device.
//
// TORCH_CHECK (rather than the deprecated AT_ASSERTM) is the macro meant for
// validating user-supplied arguments: AT_ASSERTM reports an *internal* assert
// failure and asks the user to file a PyTorch bug, which is misleading for a
// simple wrong-device input.
// `x` is parenthesized so arbitrary expressions are safe, and the body is
// wrapped in do/while(false) so the macro is a single statement even if
// TORCH_CHECK expands to a bare `if` (avoids dangling-else surprises).
#define CHECK_CUDA(x)                                                          \
  do {                                                                         \
    TORCH_CHECK((x).device().is_cuda(), #x " must be CUDA tensor");            \
  } while (false)
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
7
8
9
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
          torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
          std::string reduce);
rusty1s's avatar
rusty1s committed
10

rusty1s's avatar
rusty1s committed
11
12
13
torch::Tensor spmm_val_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
                               torch::Tensor col, torch::Tensor mat,
                               torch::Tensor grad, std::string reduce);
rusty1s's avatar
rusty1s committed
14

rusty1s's avatar
rusty1s committed
15
16
17
18
// Checks that every tensor operand is a CUDA tensor, then forwards the call
// to `spmm_cuda` and returns its result unchanged.
//
// Parameters are passed straight through to spmm_cuda:
//   rowptr, col - presumably the CSR row pointer / column index arrays of the
//                 sparse operand (TODO confirm against spmm_cuda)
//   value_opt   - optional non-zero values; only checked when present
//   mat         - the dense operand
//   reduce      - reduction name, forwarded as-is (not validated here)
// Fails via CHECK_CUDA if any provided tensor is not on a CUDA device.
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm(torch::Tensor rowptr, torch::Tensor col,
     torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
     std::string reduce) {
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  if (value_opt.has_value()) {
    CHECK_CUDA(value_opt.value());
  }
  CHECK_CUDA(mat);
  // All arguments are by-value copies that are dead after this call, so move
  // them into the callee rather than paying for another set of tensor
  // refcount increments and a std::string copy.
  return spmm_cuda(std::move(rowptr), std::move(col), std::move(value_opt),
                   std::move(mat), std::move(reduce));
}

rusty1s's avatar
rusty1s committed
27
28
29
// Checks that every tensor operand is a CUDA tensor, then forwards the call
// to `spmm_val_bw_cuda` and returns its result unchanged.
//
// Going by the name, this presumably computes the gradient of `spmm` with
// respect to the sparse values — TODO confirm against spmm_val_bw_cuda.
// Parameters are passed straight through:
//   row, rowptr, col - presumably the sparse operand's index structure
//   mat              - the dense operand
//   grad             - the incoming gradient tensor
//   reduce           - reduction name, forwarded as-is (not validated here)
// Fails via CHECK_CUDA if any tensor is not on a CUDA device.
torch::Tensor spmm_val_bw(torch::Tensor row, torch::Tensor rowptr,
                          torch::Tensor col, torch::Tensor mat,
                          torch::Tensor grad, std::string reduce) {
  CHECK_CUDA(row);
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(mat);
  CHECK_CUDA(grad);
  // By-value arguments are dead after this call; move them into the callee
  // to skip redundant refcount increments and a std::string copy.
  return spmm_val_bw_cuda(std::move(row), std::move(rowptr), std::move(col),
                          std::move(mat), std::move(grad), std::move(reduce));
}

rusty1s's avatar
rusty1s committed
38
39
40
// Registers the two CUDA-checked entry points as TorchScript operators under
// the `torch_sparse_cuda` namespace. Registration runs when this static
// object is initialized; `spmm` is registered first, then `spmm_val_bw`.
static auto registry =
    torch::RegisterOperators()
        .op("torch_sparse_cuda::spmm", &spmm)
        .op("torch_sparse_cuda::spmm_val_bw", &spmm_val_bw);