// spmm.cpp (author: rusty1s)
#include <string>
#include <tuple>
#include <utility>

#include <torch/script.h>
// Fails with an error if tensor `x` does not live on a CUDA device.
//
// Uses TORCH_CHECK (the user-facing argument check) instead of the deprecated
// internal assertion AT_ASSERTM, whose failure message wrongly asks users to
// report a PyTorch bug. Uses Tensor::is_cuda() instead of the deprecated
// Tensor::type() accessor. `x` is parenthesized so that expression arguments
// such as `value_opt.value()` expand safely.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be CUDA tensor")

rusty1s's avatar
rusty1s committed
5
6
7
8
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
          torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
          std::string reduce);
rusty1s's avatar
rusty1s committed
9

rusty1s's avatar
rusty1s committed
10
11
12
torch::Tensor spmm_val_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
                               torch::Tensor col, torch::Tensor mat,
                               torch::Tensor grad, std::string reduce);
rusty1s's avatar
rusty1s committed
13

rusty1s's avatar
rusty1s committed
14
15
16
17
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm(torch::Tensor rowptr, torch::Tensor col,
     torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
     std::string reduce) {
rusty1s's avatar
rusty1s committed
18
19
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
rusty1s's avatar
rusty1s committed
20
21
  if (value_opt.has_value())
    CHECK_CUDA(value_opt.value());
rusty1s's avatar
rusty1s committed
22
  CHECK_CUDA(mat);
rusty1s's avatar
rusty1s committed
23
  return spmm_cuda(rowptr, col, value_opt, mat, reduce);
rusty1s's avatar
rusty1s committed
24
25
}

rusty1s's avatar
rusty1s committed
26
27
28
torch::Tensor spmm_val_bw(torch::Tensor row, torch::Tensor rowptr,
                          torch::Tensor col, torch::Tensor mat,
                          torch::Tensor grad, std::string reduce) {
rusty1s's avatar
rusty1s committed
29
  CHECK_CUDA(row);
rusty1s's avatar
rusty1s committed
30
  CHECK_CUDA(rowptr);
rusty1s's avatar
rusty1s committed
31
  CHECK_CUDA(col);
rusty1s's avatar
rusty1s committed
32
33
  CHECK_CUDA(mat);
  CHECK_CUDA(grad);
rusty1s's avatar
rusty1s committed
34
  return spmm_val_bw_cuda(row, rowptr, col, mat, grad, reduce);
rusty1s's avatar
rusty1s committed
35
36
}

rusty1s's avatar
rusty1s committed
37
38
39
static auto registry =
    torch::RegisterOperators("torch_sparse_cuda::spmm", &spmm)
        .op("torch_sparse_cuda::spmm_val_bw", &spmm_val_bw);