scatter.cpp 2.97 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
#include <torch/script.h>
rusty1s's avatar
rusty1s committed
2
3
4

#include "dim_apply.h"

rusty1s's avatar
rusty1s committed
5
// Asserts that tensor expression `x` reports a CPU device. On failure the
// message is the spelled-out argument followed by " must be CPU tensor".
// The argument is parenthesized so compound expressions expand correctly.
#define CHECK_CPU(x) AT_ASSERTM((x).device().is_cpu(), #x " must be CPU tensor")
rusty1s's avatar
rusty1s committed
6

rusty1s's avatar
rusty1s committed
7
void scatter_mul(torch::Tensor src, torch::Tensor index, torch::Tensor out,
rusty1s's avatar
rusty1s committed
8
                 int64_t dim) {
rusty1s's avatar
rusty1s committed
9
10
11
  CHECK_CPU(src);
  CHECK_CPU(index);
  CHECK_CPU(out);
rusty1s's avatar
reset  
rusty1s committed
12
  int64_t elems_per_row = index.size(dim), i, idx;
rusty1s's avatar
rusty1s committed
13
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul", [&] {
rusty1s's avatar
reset  
rusty1s committed
14
15
16
17
18
19
20
    DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
      for (i = 0; i < elems_per_row; i++) {
        idx = index_data[i * index_stride];
        out_data[idx * out_stride] *= src_data[i * src_stride];
      }
    });
  });
rusty1s's avatar
rusty1s committed
21
22
}

rusty1s's avatar
rusty1s committed
23
void scatter_div(torch::Tensor src, torch::Tensor index, torch::Tensor out,
rusty1s's avatar
rusty1s committed
24
                 int64_t dim) {
rusty1s's avatar
rusty1s committed
25
26
27
  CHECK_CPU(src);
  CHECK_CPU(index);
  CHECK_CPU(out);
rusty1s's avatar
rusty1s committed
28
  int64_t elems_per_row = index.size(dim), i, idx;
rusty1s's avatar
rusty1s committed
29
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div", [&] {
rusty1s's avatar
rusty1s committed
30
31
32
33
34
    DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
      for (i = 0; i < elems_per_row; i++) {
        idx = index_data[i * index_stride];
        out_data[idx * out_stride] /= src_data[i * src_stride];
      }
rusty1s's avatar
rusty1s committed
35
    });
rusty1s's avatar
rusty1s committed
36
37
38
  });
}

rusty1s's avatar
rusty1s committed
39
40
// In-place scatter-max with argmax tracking on CPU tensors.
//
// For every position `i` in [0, index.size(dim)), reads a target offset from
// `index`. If the element of `src` at `i` is greater than or equal to the
// addressed element of `out`, that element of `out` is overwritten with the
// `src` value and the matching element of `arg` is set to `i`. Because the
// comparison is `>=`, a later position replaces an earlier one on ties.
// `out` and `arg` are modified in place; nothing is returned.
//
// All four tensors must be CPU tensors. `arg` is checked as well, since this
// kernel writes through its raw data pointer.
//
// The `*_data` pointers and `*_stride` values are supplied by DIM_APPLY4
// (dim_apply.h); presumably it runs the loop body once per 1-D slice along
// `dim` — confirm against that header.
// NOTE(review): index values are not range-checked here, and `arg` is
// accessed as int64_t — callers presumably guarantee both; verify.
void scatter_max(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                 torch::Tensor arg, int64_t dim) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  CHECK_CPU(out);
  CHECK_CPU(arg);

  const int64_t elems_per_row = index.size(dim);

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max", [&] {
    DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
               {
                 for (int64_t i = 0; i < elems_per_row; i++) {
                   const int64_t idx = index_data[i * index_stride];
                   if (src_data[i * src_stride] >= out_data[idx * out_stride]) {
                     out_data[idx * out_stride] = src_data[i * src_stride];
                     arg_data[idx * arg_stride] = i;
                   }
                 }
               });
  });
}

rusty1s's avatar
rusty1s committed
59
60
// In-place scatter-min with argmin tracking on CPU tensors.
//
// For every position `pos` in [0, index.size(dim)), reads a target offset
// from `index`. If the element of `src` at `pos` is less than or equal to the
// addressed element of `out`, that element of `out` is overwritten with the
// `src` value and the matching element of `arg` is set to `pos`. Because the
// comparison is `<=`, a later position replaces an earlier one on ties.
// `out` and `arg` are modified in place; nothing is returned.
//
// The `*_data` pointers and `*_stride` values are supplied by DIM_APPLY4
// (dim_apply.h); presumably it runs the loop body once per 1-D slice along
// `dim` — confirm against that header.
// NOTE(review): index values are not range-checked here — verify callers.
void scatter_min(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                 torch::Tensor arg, int64_t dim) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  CHECK_CPU(out);
  CHECK_CPU(arg);

  const int64_t row_length = index.size(dim);

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min", [&] {
    DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
               {
                 for (int64_t pos = 0; pos < row_length; ++pos) {
                   const int64_t target = index_data[pos * index_stride];
                   if (src_data[pos * src_stride] <=
                       out_data[target * out_stride]) {
                     out_data[target * out_stride] = src_data[pos * src_stride];
                     arg_data[target * arg_stride] = pos;
                   }
                 }
               });
  });
}

rusty1s's avatar
rusty1s committed
80
81
82
83
84
// Registers the four CPU scatter kernels as custom operators under the
// `torch_scatter_cpu` namespace. The registration happens during static
// initialization of this translation unit; `registry` holds the result.
static auto registry =
    torch::RegisterOperators()
        .op("torch_scatter_cpu::scatter_mul", &scatter_mul)
        .op("torch_scatter_cpu::scatter_div", &scatter_div)
        .op("torch_scatter_cpu::scatter_max", &scatter_max)
        .op("torch_scatter_cpu::scatter_min", &scatter_min);