// scatter_cpu.cpp
#include "scatter_cpu.h"

// #include "index_info.h"
// #include "reducer.h"
// #include "utils.h"

/// CPU entry point for the scatter operation.
///
/// NOTE: this is currently a pass-through stub. It returns `src` and
/// `optional_out` exactly as they were passed in; `index`, `dim`, `dim_size`
/// and `reduce` are ignored. The reference implementation is kept, disabled,
/// in the comment block inside the body.
/// NOTE(review): presumably disabled temporarily -- confirm before relying on
/// this function for real results.
///
/// @param src          Source tensor; returned unchanged as the first element.
/// @param index        Index tensor (unused while the implementation is
///                     disabled).
/// @param dim          Dimension argument (unused while disabled).
/// @param optional_out Optional output tensor; returned unchanged as the
///                     second element.
/// @param dim_size     Optional size of the output along `dim` (unused while
///                     disabled).
/// @param reduce       Name of the reduction (unused while disabled).
/// @return Tuple `(src, optional_out)`.
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size, std::string reduce) {
  // Early return: everything below is commented out, so the inputs are simply
  // handed back to the caller.
  return std::make_tuple(src, optional_out);

  // ---- Disabled implementation (kept for reference) ------------------------
  // Re-enabling it also requires the helper headers that are commented out at
  // the top of this file ("index_info.h", "reducer.h", "utils.h").

  // CHECK_CPU(src);
  // CHECK_CPU(index);
  // if (optional_out.has_value())
  //   CHECK_CPU(optional_out.value());

  // CHECK_INPUT(src.dim() == index.dim());
  // for (auto i = 0; i < index.dim() - 1; i++)
  //   CHECK_INPUT(src.size(i) >= index.size(i));

  // src = src.contiguous();

  // torch::Tensor out;
  // if (optional_out.has_value()) {
  //   out = optional_out.value().contiguous();
  //   for (auto i = 0; i < out.dim(); i++)
  //     if (i != dim)
  //       CHECK_INPUT(src.size(i) == out.size(i));
  // } else {
  //   auto sizes = src.sizes().vec();
  //   if (dim_size.has_value())
  //     sizes[dim] = dim_size.value();
  //   else if (index.numel() == 0)
  //     sizes[dim] = 0;
  //   else
  //     sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
  //   out = torch::empty(sizes, src.options());
  // }

  // torch::optional<torch::Tensor> arg_out = torch::nullopt;
  // int64_t *arg_out_data = nullptr;
  // if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
  //   arg_out = torch::full_like(out, src.size(dim), index.options());
  //   arg_out_data = arg_out.value().data_ptr<int64_t>();
  // }

  // if (index.numel() == 0)
  //   return std::make_tuple(out, arg_out);

  // auto B = 1;
  // for (auto i = 0; i < dim; i++)
  //   B *= src.size(i);
  // auto E = src.size(dim);
  // auto K = src.numel() / (B * E);
  // auto N = out.size(dim);

  // auto index_info = getTensorInfo<int64_t>(index);
  // AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
  //   auto src_data = src.data_ptr<scalar_t>();
  //   auto out_data = out.data_ptr<scalar_t>();

  //   int64_t i, idx;
  //   AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
  //     if (!optional_out.has_value())
  //       out.fill_(Reducer<scalar_t>::init(REDUCE));

  //     for (auto b = 0; b < B; b++) {
  //       for (auto e = 0; e < E; e++) {
  //         for (auto k = 0; k < K; k++) {
  //           i = b * E * K + e * K + k;
  //           idx = index_info.data[
  //               IndexToOffset<int64_t>::get(i, index_info)];
  //           Reducer<scalar_t>::update(
  //               REDUCE, out_data + b * N * K + idx * K + k, src_data[i],
  //               arg_out_data + b * N * K + idx * K + k, e);
  //         }
  //       }
  //     }

  //     if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
  //       out.masked_fill_(out == Reducer<scalar_t>::init(REDUCE),
  //                        (scalar_t)0);
  //   });
  // });

  // return std::make_tuple(out, arg_out);
}