// segment_coo_cpu.cpp
#include "segment_coo_cpu.h"

#include "index_info.h"
#include "reducer.h"
#include "utils.h"

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cpu(torch::Tensor src, torch::Tensor index,
                torch::optional<torch::Tensor> optional_out,
                torch::optional<int64_t> dim_size, std::string reduce) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  if (optional_out.has_value())
    CHECK_CPU(optional_out.value());

  CHECK_INPUT(src.dim() >= index.dim());

  auto sizes = index.sizes().vec();
rusty1s's avatar
rusty1s committed
19
  for (auto i = 0; i < index.dim(); i++)
rusty1s's avatar
rusty1s committed
20
21
22
23
24
25
26
27
28
29
    sizes[i] = src.size(i);
  index = index.expand(sizes);

  auto dim = index.dim() - 1;

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
rusty1s's avatar
rusty1s committed
30
    for (auto i = 0; i < out.dim(); i++)
rusty1s's avatar
rusty1s committed
31
32
33
34
35
36
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
  } else {
    sizes = src.sizes().vec();
    if (dim_size.has_value())
      sizes[dim] = dim_size.value();
rusty1s's avatar
rusty1s committed
37
38
    else if (index.numel() == 0)
      sizes[dim] = 0;
39
40
41
42
43
    else {
      auto tmp = index.select(dim, index.size(dim) - 1);
      tmp = tmp.numel() > 1 ? tmp.max() : tmp;
      sizes[dim] = 1 + *tmp.data_ptr<int64_t>();
    }
rusty1s's avatar
rusty1s committed
44
45
46
47
48
49
50
51
    out = torch::empty(sizes, src.options());
  }

  torch::optional<torch::Tensor> arg_out = torch::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = torch::full_like(out, src.size(dim), index.options());
    arg_out_data = arg_out.value().data_ptr<int64_t>();
rusty1s's avatar
rusty1s committed
52
  } else if (reduce2REDUCE.at(reduce) == MEAN) {
rusty1s's avatar
rusty1s committed
53
54
    auto sizes = index.sizes().vec();
    sizes[dim] = out.size(dim);
rusty1s's avatar
rusty1s committed
55
    arg_out = torch::zeros(sizes, out.options());
rusty1s's avatar
rusty1s committed
56
57
  }

rusty1s's avatar
rusty1s committed
58
59
60
  if (src.numel() == 0) {
    if (!optional_out.has_value())
      out.fill_(0);
rusty1s's avatar
rusty1s committed
61
    return std::make_tuple(out, arg_out);
rusty1s's avatar
rusty1s committed
62
  }
rusty1s's avatar
rusty1s committed
63

rusty1s's avatar
rusty1s committed
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
  auto B = index.numel() / src.size(dim);
  auto E = src.size(dim);
  auto K = src.numel() / index.numel();
  auto N = out.size(dim);

  auto index_info = getTensorInfo<int64_t>(index);
  auto stride = index_info.strides[index_info.dims - 1];
  std::vector<int64_t> args(K);
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();
    scalar_t *count_data = nullptr;

    std::vector<scalar_t> vals(K);
    int64_t idx, next_idx, row_start;
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (!optional_out.has_value())
Stefan Ivanov's avatar
Stefan Ivanov committed
81
        out.fill_(Reducer<scalar_t, REDUCE>::init());
rusty1s's avatar
rusty1s committed
82
      if (REDUCE == MEAN)
rusty1s's avatar
rusty1s committed
83
        count_data = arg_out.value().data_ptr<scalar_t>();
rusty1s's avatar
rusty1s committed
84
85
86
87
88
89
90
91
92
93
94
95

      for (auto b = 0; b < B; b++) {
        auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
        idx = index_info.data[offset];

        for (auto k = 0; k < K; k++)
          vals[k] = out_data[b * N * K + k];

        row_start = 0;
        for (auto e = 0; e < E; e++) {

          for (auto k = 0; k < K; k++)
Stefan Ivanov's avatar
Stefan Ivanov committed
96
97
            Reducer<scalar_t, REDUCE>::update(
                &vals[k], src_data[b * E * K + e * K + k], &args[k], e);
rusty1s's avatar
rusty1s committed
98
99
100

          if (e == E - 1) {
            for (auto k = 0; k < K; k++)
Stefan Ivanov's avatar
Stefan Ivanov committed
101
102
              Reducer<scalar_t, REDUCE>::write(
                  out_data + b * N * K + idx * K + k, vals[k],
rusty1s's avatar
rusty1s committed
103
104
105
106
107
108
109
110
111
112
                  arg_out_data + b * N * K + idx * K + k, args[k],
                  e + 1 - row_start);
            if (REDUCE == MEAN)
              count_data[b * N + idx] = (scalar_t)(e + 1 - row_start);
          } else {
            next_idx = index_info.data[offset + (e + 1) * stride];
            assert(idx <= next_idx);

            if (idx != next_idx) {
              for (auto k = 0; k < K; k++) {
Stefan Ivanov's avatar
Stefan Ivanov committed
113
114
                Reducer<scalar_t, REDUCE>::write(
                    out_data + b * N * K + idx * K + k, vals[k],
rusty1s's avatar
rusty1s committed
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
                    arg_out_data + b * N * K + idx * K + k, args[k],
                    e + 1 - row_start);

                vals[k] = out_data[b * N * K + next_idx * K + k];
              }
              if (REDUCE == MEAN)
                count_data[b * N + idx] = (scalar_t)(e + 1 - row_start);
              row_start = e + 1;
            }

            idx = next_idx;
          }
        }
      }
      if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
Stefan Ivanov's avatar
Stefan Ivanov committed
130
        out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
rusty1s's avatar
rusty1s committed
131
132

      if (REDUCE == MEAN)
rusty1s's avatar
rusty1s committed
133
        arg_out.value().clamp_(1);
rusty1s's avatar
rusty1s committed
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
    });
  });

  return std::make_tuple(out, arg_out);
}

/// CPU gather of `src` rows selected by a COO-style `index`.
///
/// For every position `e` along `dim = index.dim() - 1`, the slice of `src`
/// at `index[..., e]` along `dim` is copied to position `e` of `out`.
///
/// @param src          Source values (CPU tensor, `src.dim() >= index.dim()`);
///                     must match `index` in all dims before `dim`.
/// @param index        int64 positions into `src` along `dim` (CPU tensor).
///                     Must be non-decreasing along its last dimension; this
///                     is verified while iterating via `CHECK_INPUT`.
/// @param optional_out Optional output tensor. Must match `src` in all dims
///                     except `dim`, and have `index.size(dim)` entries along
///                     `dim`.
/// @return The gathered tensor (a contiguous version of `optional_out` when
///         given, otherwise a new tensor with `index.size(dim)` entries along
///         `dim`).
///
/// NOTE(review): index values are not range-checked against `src.size(dim)`
/// in this function — presumably guaranteed by the caller; confirm.
torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
                             torch::optional<torch::Tensor> optional_out) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  if (optional_out.has_value())
    CHECK_CPU(optional_out.value());

  CHECK_INPUT(src.dim() >= index.dim());
  for (auto i = 0; i < index.dim() - 1; i++)
    CHECK_INPUT(src.size(i) == index.size(i));

  // Gathering always runs along the last dimension of `index`.
  auto dim = index.dim() - 1;

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (auto i = 0; i < src.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
    // The write loop below addresses `out` with a row length of
    // `index.size(dim)`, so a differently sized `out` would be indexed
    // incorrectly.
    CHECK_INPUT(out.size(dim) == index.size(dim));
  } else {
    auto sizes = src.sizes().vec();
    sizes[dim] = index.size(dim);
    out = torch::empty(sizes, src.options());
  }

  // Nothing to gather from: return a zero-filled result (or the given `out`).
  if (src.numel() == 0) {
    if (!optional_out.has_value())
      out.fill_(0);
    return out;
  }

  // B: number of index rows, E: entries per row along `dim`,
  // K: trailing elements per index entry, N: source size along `dim`.
  auto B = index.numel() / out.size(dim);
  auto E = index.size(dim);
  auto K = out.numel() / index.numel();
  auto N = src.size(dim);

  auto index_info = getTensorInfo<int64_t>(index);
  auto stride = index_info.strides[index_info.dims - 1];
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    std::vector<scalar_t> vals(K);
    int64_t idx, next_idx;
    for (auto b = 0; b < B; b++) {
      auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
      idx = index_info.data[offset];

      // Cache the source slice for the current index value.
      for (auto k = 0; k < K; k++)
        vals[k] = src_data[b * N * K + idx * K + k];

      for (auto e = 0; e < E; e++) {
        for (auto k = 0; k < K; k++)
          out_data[b * E * K + e * K + k] = vals[k];

        if (e < E - 1) {
          next_idx = index_info.data[offset + (e + 1) * stride];
          CHECK_INPUT(idx <= next_idx);

          // Reload the cached slice only when the index value changes.
          if (idx != next_idx) {
            idx = next_idx;
            for (auto k = 0; k < K; k++)
              vals[k] = src_data[b * N * K + idx * K + k];
          }
        }
      }
    }
  });

  return out;
}