// scatter.cpp — author: rusty1s (per git blame)
#include <torch/script.h>
// Rejects any tensor that does not live on a CUDA device. The error message
// embeds the argument's spelling, e.g. CHECK_CUDA(src) reports
// "src must be CUDA tensor".
// TORCH_CHECK is PyTorch's macro for validating caller-supplied input; the
// previously used AT_ASSERTM is deprecated and reports failures as internal
// invariant violations, which is misleading for a simple device mismatch.
// The argument is parenthesized so expressions expand safely; #x still
// stringifies the argument exactly as written.
#define CHECK_CUDA(x)                                                          \
  TORCH_CHECK((x).device().is_cuda(), #x " must be CUDA tensor")

void scatter_mul_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
rusty1s's avatar
rusty1s committed
7
                      int64_t dim);
rusty1s's avatar
rusty1s committed
8
void scatter_div_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
rusty1s's avatar
rusty1s committed
9
                      int64_t dim);
rusty1s's avatar
rusty1s committed
10
11
12
13
14
15
void scatter_max_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                      torch::Tensor arg, int64_t dim);
void scatter_min_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                      torch::Tensor arg, int64_t dim);
void index_backward_cuda(torch::Tensor grad, torch::Tensor index,
                         torch::Tensor arg, torch::Tensor out, int64_t dim);
rusty1s's avatar
rusty1s committed
16

rusty1s's avatar
rusty1s committed
17
void scatter_mul(torch::Tensor src, torch::Tensor index, torch::Tensor out,
rusty1s's avatar
rusty1s committed
18
19
20
21
22
23
24
                 int64_t dim) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_CUDA(out);
  scatter_mul_cuda(src, index, out, dim);
}

rusty1s's avatar
rusty1s committed
25
void scatter_div(torch::Tensor src, torch::Tensor index, torch::Tensor out,
rusty1s's avatar
rusty1s committed
26
27
28
29
30
31
32
                 int64_t dim) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_CUDA(out);
  scatter_div_cuda(src, index, out, dim);
}

rusty1s's avatar
rusty1s committed
33
34
// Host-side entry point for scatter-max.
// Verifies that all four tensor arguments are on a CUDA device, then forwards
// everything unchanged to scatter_max_cuda. Nothing is returned, so results
// presumably land in `out` (values) and `arg` (positions) — confirm against
// the kernel.
//
//   src   - source values
//   index - scatter indices along `dim`
//   out   - destination tensor handed to the kernel
//   arg   - auxiliary tensor handed to the kernel
//   dim   - dimension forwarded to the kernel
void scatter_max(torch::Tensor src,
                 torch::Tensor index,
                 torch::Tensor out,
                 torch::Tensor arg,
                 int64_t dim) {
  // Checked in parameter order; the first non-CUDA tensor determines which
  // "<name> must be CUDA tensor" error is raised.
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_CUDA(out);
  CHECK_CUDA(arg);

  scatter_max_cuda(src, index, out, arg, dim);
}

// Host-side entry point for scatter-min.
// Verifies that all four tensor arguments are on a CUDA device, then forwards
// everything unchanged to scatter_min_cuda. Nothing is returned, so results
// presumably land in `out` (values) and `arg` (positions) — confirm against
// the kernel.
//
//   src   - source values
//   index - scatter indices along `dim`
//   out   - destination tensor handed to the kernel
//   arg   - auxiliary tensor handed to the kernel
//   dim   - dimension forwarded to the kernel
void scatter_min(torch::Tensor src,
                 torch::Tensor index,
                 torch::Tensor out,
                 torch::Tensor arg,
                 int64_t dim) {
  // Checked in parameter order; the first non-CUDA tensor determines which
  // "<name> must be CUDA tensor" error is raised.
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_CUDA(out);
  CHECK_CUDA(arg);

  scatter_min_cuda(src, index, out, arg, dim);
}

// Registers the four wrappers as custom operators in the "torch_scatter_cuda"
// namespace. Registration runs when this namespace-scope static is
// initialized, i.e. when the translation unit is loaded; the object must stay
// alive for the operators to remain registered.
static auto registry =
    torch::RegisterOperators()
        .op("torch_scatter_cuda::scatter_mul", &scatter_mul)
        .op("torch_scatter_cuda::scatter_div", &scatter_div)
        .op("torch_scatter_cuda::scatter_max", &scatter_max)
        .op("torch_scatter_cuda::scatter_min", &scatter_min);