// scatter.cpp - CPU scatter kernels exposed to Python through pybind11.
#include <stdexcept>

#include <torch/extension.h>

#include "dim_apply.h"

/// Scatter-multiply `src` into `out` along dimension `dim`, in place.
///
/// Inside the DIM_APPLY3 body, for each position `i` in
/// `[0, index.size(dim))` this performs `out[index[i]] *= src[i]`, where the
/// offsets are scaled by the per-tensor strides the macro provides.
/// NOTE(review): DIM_APPLY3 comes from dim_apply.h (not visible here); it is
/// presumed to run the body once per 1-D slice along `dim` - confirm there.
///
/// @param src   Values to multiply in; its scalar type selects the kernel.
/// @param index Target positions along `dim`, read as int64_t.
/// @param out   Accumulator updated in place, accessed as src's scalar type.
/// @param dim   Dimension along which to scatter.
/// @throws std::out_of_range if an index value is negative or not smaller
///         than `out.size(dim)`. Entries processed before the offending one
///         have already been applied to `out`.
void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
                 int64_t dim) {
  const int64_t elems_per_row = index.size(dim);
  // Valid target positions along `dim` are [0, idx_limit).
  const int64_t idx_limit = out.size(dim);

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul", [&] {
    // The braced body is a macro argument: keep it free of top-level commas.
    DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
      for (int64_t i = 0; i < elems_per_row; i++) {
        const int64_t idx = index_data[i * index_stride];
        // Reject indices that would address memory outside of `out`.
        if (idx < 0 || idx >= idx_limit) {
          throw std::out_of_range("scatter_mul: index out of range");
        }
        out_data[idx * out_stride] *= src_data[i * src_stride];
      }
    });
  });
}

/// Scatter-divide `out` by `src` along dimension `dim`, in place.
///
/// Inside the DIM_APPLY3 body, for each position `i` in
/// `[0, index.size(dim))` this performs `out[index[i]] /= src[i]`, where the
/// offsets are scaled by the per-tensor strides the macro provides.
/// NOTE(review): DIM_APPLY3 comes from dim_apply.h (not visible here); it is
/// presumed to run the body once per 1-D slice along `dim` - confirm there.
///
/// @param src   Divisors; its scalar type selects the kernel.
/// @param index Target positions along `dim`, read as int64_t.
/// @param out   Accumulator updated in place, accessed as src's scalar type.
/// @param dim   Dimension along which to scatter.
/// @throws std::out_of_range if an index value is negative or not smaller
///         than `out.size(dim)`. Entries processed before the offending one
///         have already been applied to `out`.
void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
                 int64_t dim) {
  const int64_t elems_per_row = index.size(dim);
  // Valid target positions along `dim` are [0, idx_limit).
  const int64_t idx_limit = out.size(dim);

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div", [&] {
    // The braced body is a macro argument: keep it free of top-level commas.
    DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
      for (int64_t i = 0; i < elems_per_row; i++) {
        const int64_t idx = index_data[i * index_stride];
        // Reject indices that would address memory outside of `out`.
        if (idx < 0 || idx >= idx_limit) {
          throw std::out_of_range("scatter_div: index out of range");
        }
        out_data[idx * out_stride] /= src_data[i * src_stride];
      }
    });
  });
}

/// Scatter-maximum of `src` into `out` along `dim`, recording the source
/// position of each winning value in `arg`.
///
/// Inside the DIM_APPLY4 body, for each position `i` in
/// `[0, index.size(dim))`: if `src[i] >= out[index[i]]`, then `out[index[i]]`
/// is set to `src[i]` and `arg[index[i]]` is set to `i`. Because the
/// comparison is `>=`, the later position wins on ties.
/// NOTE(review): DIM_APPLY4 comes from dim_apply.h (not visible here); it is
/// presumed to run the body once per 1-D slice along `dim` - confirm there.
///
/// @param src   Candidate values; its scalar type selects the kernel.
/// @param index Target positions along `dim`, read as int64_t.
/// @param out   Running maximum, updated in place (src's scalar type).
/// @param arg   Receives the winning position along `dim`, written as int64_t.
/// @param dim   Dimension along which to scatter.
/// @throws std::out_of_range if an index value is negative or not smaller
///         than both `out.size(dim)` and `arg.size(dim)`. Entries processed
///         before the offending one have already been applied.
void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
                 at::Tensor arg, int64_t dim) {
  const int64_t elems_per_row = index.size(dim);
  // `idx` addresses both `out` and `arg`; it must be valid for the smaller one.
  const int64_t out_extent = out.size(dim);
  const int64_t arg_extent = arg.size(dim);
  const int64_t idx_limit = out_extent < arg_extent ? out_extent : arg_extent;

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max", [&] {
    // The braced body is a macro argument: keep it free of top-level commas.
    DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
               {
                 for (int64_t i = 0; i < elems_per_row; i++) {
                   const int64_t idx = index_data[i * index_stride];
                   // Reject indices outside of `out` / `arg`.
                   if (idx < 0 || idx >= idx_limit) {
                     throw std::out_of_range("scatter_max: index out of range");
                   }
                   if (src_data[i * src_stride] >= out_data[idx * out_stride]) {
                     out_data[idx * out_stride] = src_data[i * src_stride];
                     arg_data[idx * arg_stride] = i;
                   }
                 }
               });
  });
}

/// Scatter-minimum of `src` into `out` along `dim`, recording the source
/// position of each winning value in `arg`.
///
/// Inside the DIM_APPLY4 body, for each position `i` in
/// `[0, index.size(dim))`: if `src[i] <= out[index[i]]`, then `out[index[i]]`
/// is set to `src[i]` and `arg[index[i]]` is set to `i`. Because the
/// comparison is `<=`, the later position wins on ties.
/// NOTE(review): DIM_APPLY4 comes from dim_apply.h (not visible here); it is
/// presumed to run the body once per 1-D slice along `dim` - confirm there.
///
/// @param src   Candidate values; its scalar type selects the kernel.
/// @param index Target positions along `dim`, read as int64_t.
/// @param out   Running minimum, updated in place (src's scalar type).
/// @param arg   Receives the winning position along `dim`, written as int64_t.
/// @param dim   Dimension along which to scatter.
/// @throws std::out_of_range if an index value is negative or not smaller
///         than both `out.size(dim)` and `arg.size(dim)`. Entries processed
///         before the offending one have already been applied.
void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
                 at::Tensor arg, int64_t dim) {
  const int64_t elems_per_row = index.size(dim);
  // `idx` addresses both `out` and `arg`; it must be valid for the smaller one.
  const int64_t out_extent = out.size(dim);
  const int64_t arg_extent = arg.size(dim);
  const int64_t idx_limit = out_extent < arg_extent ? out_extent : arg_extent;

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min", [&] {
    // The braced body is a macro argument: keep it free of top-level commas.
    DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
               {
                 for (int64_t i = 0; i < elems_per_row; i++) {
                   const int64_t idx = index_data[i * index_stride];
                   // Reject indices outside of `out` / `arg`.
                   if (idx < 0 || idx >= idx_limit) {
                     throw std::out_of_range("scatter_min: index out of range");
                   }
                   if (src_data[i * src_stride] <= out_data[idx * out_stride]) {
                     out_data[idx * out_stride] = src_data[i * src_stride];
                     arg_data[idx * arg_stride] = i;
                   }
                 }
               });
  });
}

/// Routes values of `grad` back to the source positions recorded in `arg`.
///
/// Inside the DIM_APPLY4 body, for each position `i` in
/// `[0, index.size(dim))`: if `arg[index[i]] == i`, then `out[i]` is set to
/// `grad[index[i]]`. Positions of `out` that do not match are left untouched.
/// NOTE(review): presumably `arg` is the argmax/argmin tensor produced by
/// scatter_max / scatter_min - verify against the caller. DIM_APPLY4 comes
/// from dim_apply.h (not visible here); it is presumed to run the body once
/// per 1-D slice along `dim` - confirm there.
///
/// @param grad  Values to route back; its scalar type selects the kernel.
/// @param index Target positions along `dim`, read as int64_t.
/// @param arg   Recorded winning positions along `dim`, read as int64_t.
/// @param out   Destination, written in place (grad's scalar type).
/// @param dim   Dimension along which to operate.
/// @throws std::out_of_range if an index value is negative or not smaller
///         than both `arg.size(dim)` and `grad.size(dim)`. Entries processed
///         before the offending one have already been written to `out`.
void index_backward(at::Tensor grad, at::Tensor index, at::Tensor arg,
                    at::Tensor out, int64_t dim) {
  const int64_t elems_per_row = index.size(dim);
  // `idx` addresses both `arg` and `grad`; it must be valid for the smaller.
  const int64_t arg_extent = arg.size(dim);
  const int64_t grad_extent = grad.size(dim);
  const int64_t idx_limit = arg_extent < grad_extent ? arg_extent : grad_extent;

  AT_DISPATCH_ALL_TYPES(grad.scalar_type(), "index_backward", [&] {
    // The braced body is a macro argument: keep it free of top-level commas.
    DIM_APPLY4(scalar_t, grad, int64_t, index, int64_t, arg, scalar_t, out, dim,
               {
                 for (int64_t i = 0; i < elems_per_row; i++) {
                   const int64_t idx = index_data[i * index_stride];
                   // Reject indices outside of `arg` / `grad`.
                   if (idx < 0 || idx >= idx_limit) {
                     throw std::out_of_range(
                         "index_backward: index out of range");
                   }
                   if (arg_data[idx * arg_stride] == i) {
                     out_data[i * out_stride] = grad_data[idx * grad_stride];
                   }
                 }
               });
  });
}

// Python bindings: registers each CPU kernel defined in this file on module
// `m` under its C++ name, with a short docstring.
// NOTE(review): TORCH_EXTENSION_NAME is not defined in this file; presumably
// it is supplied by the extension build - confirm in the build script.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // In-place scatter reductions along a dimension.
  m.def("scatter_mul", &scatter_mul, "Scatter Mul (CPU)");
  m.def("scatter_div", &scatter_div, "Scatter Div (CPU)");
  // Max/min variants additionally fill an `arg` tensor with source positions.
  m.def("scatter_max", &scatter_max, "Scatter Max (CPU)");
  m.def("scatter_min", &scatter_min, "Scatter Min (CPU)");
  // Copies `grad` entries back to the positions recorded in `arg`.
  m.def("index_backward", &index_backward, "Index Backward (CPU)");
}