segment.cpp 1.29 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
#include <utility>

#include <torch/script.h>
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
4
// Fails with a user-facing error unless tensor `x` lives on a CUDA device;
// the message embeds the argument expression as written (via #x).
// TORCH_CHECK is used instead of the deprecated AT_ASSERTM, which routes to
// TORCH_INTERNAL_ASSERT and reports bad user input as an internal PyTorch bug.
// The do/while(false) wrapper keeps the macro a single statement, because
// TORCH_CHECK may expand to a bare `if` (dangling-else hazard otherwise).
#define CHECK_CUDA(x)                                                          \
  do {                                                                         \
    TORCH_CHECK((x).device().is_cuda(), #x " must be CUDA tensor");            \
  } while (false)
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
7
8
9
10
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
                 torch::optional<torch::Tensor> out_opt, std::string reduce);
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
rusty1s's avatar
rusty1s committed
11
                 std::string reduce);
rusty1s's avatar
rusty1s committed
12

rusty1s's avatar
rusty1s committed
13
14
15
// Entry point registered as "torch_scatter_cuda::segment_csr".
// Checks that every tensor argument is a CUDA tensor, then forwards all
// arguments unchanged to segment_csr_cuda and returns its result.
//
//   src, indptr : must be CUDA tensors.
//   out_opt     : optional tensor; checked only when a value is present.
//   reduce      : passed through verbatim to the CUDA implementation.
//
// The by-value parameters are not used after the call, so they are moved
// into the callee: this avoids copying `reduce` and redundant (atomic)
// refcount increments/decrements on the tensor handles.
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr(torch::Tensor src, torch::Tensor indptr,
            torch::optional<torch::Tensor> out_opt, std::string reduce) {
  CHECK_CUDA(src);
  CHECK_CUDA(indptr);
  if (out_opt.has_value()) {
    CHECK_CUDA(out_opt.value());
  }
  return segment_csr_cuda(std::move(src), std::move(indptr),
                          std::move(out_opt), std::move(reduce));
}

rusty1s's avatar
rusty1s committed
23
24
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo(torch::Tensor src, torch::Tensor index, torch::Tensor out,
rusty1s's avatar
rusty1s committed
25
            std::string reduce) {
rusty1s's avatar
rusty1s committed
26
27
  CHECK_CUDA(src);
  CHECK_CUDA(index);
rusty1s's avatar
rusty1s committed
28
  CHECK_CUDA(out);
rusty1s's avatar
rusty1s committed
29
  return segment_coo_cuda(src, index, out, reduce);
rusty1s's avatar
rusty1s committed
30
31
}

rusty1s's avatar
rusty1s committed
32
33
34
// Registers both wrappers as TorchScript operators under the
// "torch_scatter_cuda" namespace. Registration runs during static
// initialization of this translation unit, since `registry` is a
// namespace-scope static object.
static auto registry =
    torch::RegisterOperators()
        .op("torch_scatter_cuda::segment_csr", &segment_csr)
        .op("torch_scatter_cuda::segment_coo", &segment_coo);