#include <utility>

#include <torch/script.h>
// Rejects tensor `x` with a user-facing error unless it lives on a CUDA
// device. TORCH_CHECK is used instead of the deprecated AT_ASSERTM, which
// reports failures as internal PyTorch bugs. `#x` stringifies the argument
// expression, so the message names exactly what the caller passed in.
// The argument is parenthesized so arbitrary expressions expand safely.
#define CHECK_CUDA(x)                                                          \
  TORCH_CHECK((x).device().is_cuda(), #x " must be CUDA tensor")
// Forward declarations of the CUDA implementations that the wrappers in this
// file dispatch to. Their definitions are not in this file (presumably a
// companion .cu translation unit — confirm); these signatures must match
// those definitions exactly.
//
// gather_csr_cuda: `indptr` is presumably a CSR-style pointer tensor — confirm
// against the implementation. `out_opt` is an optional caller-supplied output
// tensor, forwarded as-is.
torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
                              torch::optional<torch::Tensor> out_opt);
// gather_coo_cuda: `index` is presumably a COO-style index tensor — confirm
// against the implementation. `out_opt` is an optional caller-supplied output
// tensor, forwarded as-is.
torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
                              torch::optional<torch::Tensor> out_opt);
/// Validates that every provided tensor is on a CUDA device, then forwards
/// all arguments to `gather_csr_cuda`.
///
/// @param src     Source tensor; must be a CUDA tensor.
/// @param indptr  Pointer tensor (presumably CSR-style — confirm against
///                gather_csr_cuda); must be a CUDA tensor.
/// @param out_opt Optional output tensor; if present it must be a CUDA tensor.
/// @return The tensor returned by `gather_csr_cuda`.
/// Fails via CHECK_CUDA if any provided tensor is not on a CUDA device.
torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
                         torch::optional<torch::Tensor> out_opt) {
  CHECK_CUDA(src);
  CHECK_CUDA(indptr);
  if (out_opt.has_value()) {
    CHECK_CUDA(out_opt.value());
  }
  // The parameters are by-value sinks. Moving them into the callee avoids an
  // extra (atomic) reference-count increment/decrement per tensor.
  return gather_csr_cuda(std::move(src), std::move(indptr),
                         std::move(out_opt));
}

rusty1s's avatar
rusty1s committed
20
21
/// Validates that every provided tensor is on a CUDA device, then forwards
/// all arguments to `gather_coo_cuda`.
///
/// @param src     Source tensor; must be a CUDA tensor.
/// @param index   Index tensor (presumably COO-style — confirm against
///                gather_coo_cuda); must be a CUDA tensor.
/// @param out_opt Optional output tensor; if present it must be a CUDA tensor.
/// @return The tensor returned by `gather_coo_cuda`.
/// Fails via CHECK_CUDA if any provided tensor is not on a CUDA device.
torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
                         torch::optional<torch::Tensor> out_opt) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  if (out_opt.has_value()) {
    CHECK_CUDA(out_opt.value());
  }
  // The parameters are by-value sinks. Moving them into the callee avoids an
  // extra (atomic) reference-count increment/decrement per tensor.
  return gather_coo_cuda(std::move(src), std::move(index),
                         std::move(out_opt));
}

rusty1s's avatar
rusty1s committed
29
30
31
// Registers the two wrappers above as custom operators under the
// `torch_scatter_cuda` namespace. Registration is performed when this
// file-local static object is initialized.
static auto registry =
    torch::RegisterOperators()
        .op("torch_scatter_cuda::gather_csr", &gather_csr)
        .op("torch_scatter_cuda::gather_coo", &gather_coo);