// scatter_cpu.cpp
#include "scatter_cpu.h"

#include "index_info.h"
#include "reducer.h"
#include "utils.h"

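// CPU implementation of `scatter`: reduces all values from `src` into `out`
// at the positions given by `index` along dimension `dim`, using the
// reduction named in `reduce`. Returns the output tensor and, for the
// "min"/"max" reductions, the tensor of arg indices.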
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size, std::string reduce) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  if (optional_out.has_value())
    CHECK_CPU(optional_out.value());

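  // `index` must have the same dimensionality as `src` and may not be larger
  // than `src` in any dimension except the last.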
  CHECK_INPUT(src.dim() == index.dim());
  for (auto i = 0; i < index.dim() - 1; i++)
    CHECK_INPUT(src.size(i) >= index.size(i));

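  // Work on a contiguous copy of `src` so it can be addressed with the flat
  // offsets computed below.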
  src = src.contiguous();

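  // Either reuse the user-provided output (its shape must match `src` in
  // every dimension except `dim`) or allocate a new one. When allocating, the
  // size along `dim` comes from `dim_size` if given, and is otherwise
  // inferred as `max(index) + 1` (or 0 for an empty index).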
  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (auto i = 0; i < out.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
  } else {
    auto sizes = src.sizes().vec();
    if (dim_size.has_value())
      sizes[dim] = dim_size.value();
    else if (index.numel() == 0)
      sizes[dim] = 0;
    else
      sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
    out = torch::empty(sizes, src.options());
  }

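  // For "min"/"max" reductions, additionally track the position of the
  // selected element. Slots are initialized to the sentinel `src.size(dim)`,
  // which marks outputs that never receive a value.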
  torch::optional<torch::Tensor> arg_out = torch::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = torch::full_like(out, src.size(dim), index.options());
    arg_out_data = arg_out.value().data_ptr<int64_t>();
  }

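  // With an empty index there is nothing to scatter; return early.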
  if (index.numel() == 0)
    return std::make_tuple(out, arg_out);

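  // View `src` as a collapsed (B, E, K) tensor: B is the product of the
  // dimensions before `dim`, E the size of `dim` itself, and K the product of
  // the dimensions after `dim`. N is the output size along `dim`.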
  auto B = 1;
  for (auto i = 0; i < dim; i++)
    B *= src.size(i);
  auto E = src.size(dim);
  auto K = src.numel() / (B * E);
  auto N = out.size(dim);

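  // Dispatch over all supported scalar types of `src`; the lambda is
  // instantiated per type with `scalar_t` bound accordingly.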
  auto index_info = getTensorInfo<int64_t>(index);
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    int64_t i, idx;
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
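      // A freshly allocated output starts from the reduction's identity value.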
      if (!optional_out.has_value())
        out.fill_(Reducer<scalar_t>::init(REDUCE));

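      // Main scatter loop: fold every element src[b, e, k] into
      // out[b, idx, k], where idx is the matching entry of `index`
      // (looked up via IndexToOffset). For "min"/"max", the position `e`
      // is recorded in arg_out whenever its value is selected.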
      for (auto b = 0; b < B; b++) {
        for (auto e = 0; e < E; e++) {
          for (auto k = 0; k < K; k++) {
            i = b * E * K + e * K + k;
            idx = index_info.data[IndexToOffset<int64_t>::get(i, index_info)];
            Reducer<scalar_t>::update(
                REDUCE, out_data + b * N * K + idx * K + k, src_data[i],
                arg_out_data + b * N * K + idx * K + k, e);
          }
        }
      }

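      // For "min"/"max" on a freshly allocated output, positions that
      // received no value still hold the initialization value; reset them
      // to zero.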
      if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
        out.masked_fill_(out == Reducer<scalar_t>::init(REDUCE), (scalar_t)0);
    });
  });

  return std::make_tuple(out, arg_out);
}