// segment.cpp: CPU kernels for the `segment_csr` and `segment_coo` operators,
// registered under the `torch_scatter_cpu` namespace.
#include <torch/script.h>

#include "compat.h"
#include "index_info.h"

#include <limits>
#include <map>
#include <string>
#include <tuple>
#include <vector>

// Asserts (via AT_ASSERTM) that tensor `x` lives on the CPU.
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
// Reduction modes understood by the segment kernels.
enum ReductionType { SUM, MEAN, MIN, MAX };

// Maps the user-facing `reduce` string onto a ReductionType; "add" is an alias
// for "sum". The dispatch macro below looks names up with `.at()`, so an
// unsupported name throws std::out_of_range.
const std::map<std::string, ReductionType> reduce2REDUCE = {
    {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};

// AT_DISPATCH_REDUCTION_TYPES(reduce, fn): looks up the runtime string
// `reduce` in reduce2REDUCE and calls the nullary callable `fn` in a scope
// where `REDUCE` is a constant ReductionType, so `fn` can use it as a template
// argument (e.g. Reducer<scalar_t, REDUCE>). Evaluates to what `fn` returns.
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    switch (reduce2REDUCE.at(reduce)) {                                        \
    case SUM: {                                                                \
      const ReductionType REDUCE = SUM;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MEAN: {                                                               \
      const ReductionType REDUCE = MEAN;                                       \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MIN: {                                                                \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MAX: {                                                                \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    }                                                                          \
  }()

// Element-wise reduction primitives, selected at compile time by REDUCE.
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  // Starting value of a reduction.
  // - SUM/MEAN: 0.
  // - MIN/MAX: +infinity / -infinity when scalar_t has infinities, otherwise
  //   max() / lowest().
  // Using the infinities (instead of +/-max()) lets a segment made only of
  // +inf (MIN) or -inf (MAX) reduce to that infinity rather than to the finite
  // extreme.
  static inline scalar_t init() {
    if (REDUCE == MIN) {
      return std::numeric_limits<scalar_t>::has_infinity
                 ? std::numeric_limits<scalar_t>::infinity()
                 : std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      return std::numeric_limits<scalar_t>::has_infinity
                 ? -std::numeric_limits<scalar_t>::infinity()
                 : std::numeric_limits<scalar_t>::lowest();
    } else {
      return (scalar_t)0;
    }
  }

  // Folds `new_val`, found at position `new_arg`, into the running value
  // `*val`.
  // - SUM/MEAN accumulate.
  // - MIN/MAX replace `*val` and record `new_arg` in `*arg` only on a strict
  //   improvement. Ties therefore keep the previously recorded position, and
  //   NaN inputs never win.
  static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
                            int64_t new_arg) {
    if (REDUCE == SUM || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
      *arg = new_arg;
    }
  }

  // Stores the final value `val` of a segment that held `count` elements.
  // - MEAN divides by `count` (by 1 for an empty segment).
  // - MIN/MAX store `val` and `arg` for a non-empty segment. For an empty
  //   segment they store 0 and leave `*arg_address` untouched.
  static inline void write(scalar_t *address, scalar_t val,
                           int64_t *arg_address, int64_t arg, int count) {
    if (REDUCE == SUM) {
      *address = val;
    } else if (REDUCE == MEAN) {
      *address = val / (count > 0 ? count : (scalar_t)1);
    } else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else {
        *address = (scalar_t)0;
      }
    }
  }
};

rusty1s's avatar
rusty1s committed
77
78
79
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr(torch::Tensor src, torch::Tensor indptr,
            torch::optional<torch::Tensor> out_opt, std::string reduce) {
rusty1s's avatar
rusty1s committed
80
81
82
83
  CHECK_CPU(src);
  CHECK_CPU(indptr);
  if (out_opt.has_value())
    CHECK_CPU(out_opt.value());
rusty1s's avatar
rusty1s committed
84
85
86
87
88
89
90
91
92
93
94
95
96

  AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");

  // Broadcasting `indptr` via `expand`.
  auto sizes = indptr.sizes().vec();
  for (int i = 0; i < indptr.dim() - 1; i++) {
    sizes[i] = src.size(i);
  }
  indptr = indptr.expand(sizes);

  src = src.contiguous();
  auto reduce_dim = indptr.dim() - 1;

rusty1s's avatar
rusty1s committed
97
  torch::Tensor out;
rusty1s's avatar
rusty1s committed
98
99
100
101
102
103
104
105
106
107
  if (out_opt.has_value()) {
    out = out_opt.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != reduce_dim)
        AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
    AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1,
               "Input mismatch");
  } else {
    sizes = src.sizes().vec();
    sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
rusty1s's avatar
rusty1s committed
108
    out = torch::empty(sizes, src.options());
rusty1s's avatar
rusty1s committed
109
110
  }

rusty1s's avatar
rusty1s committed
111
  torch::optional<torch::Tensor> arg_out = torch::nullopt;
rusty1s's avatar
rusty1s committed
112
  int64_t *arg_out_data = nullptr;
rusty1s's avatar
rusty1s committed
113
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
rusty1s's avatar
rusty1s committed
114
    arg_out = torch::full_like(out, src.size(reduce_dim), indptr.options());
rusty1s's avatar
rusty1s committed
115
116
117
118
119
120
121
122
123
124
125
126
127
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(reduce_dim);

  auto indptr_info = getTensorInfo<int64_t>(indptr);
  auto stride = indptr_info.strides[indptr_info.dims - 1];
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

128
129
130
    std::vector<scalar_t> vals(K);
    int64_t row_start, row_end;
    std::vector<int64_t> args(K);
rusty1s's avatar
rusty1s committed
131
132
133
134
135
136
137
138
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      for (int n = 0; n < N; n++) {
        int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
        row_start = indptr_info.data[offset];
        row_end = indptr_info.data[offset + stride];

        offset = (n / (indptr.size(-1) - 1)) * E * K;
        for (int k = 0; k < K; k++) {
rusty1s's avatar
rusty1s committed
139
          vals[k] = Reducer<scalar_t, REDUCE>::init();
140
141
142
        }
        for (int64_t e = row_start; e < row_end; e++) {
          for (int k = 0; k < K; k++) {
rusty1s's avatar
rusty1s committed
143
            Reducer<scalar_t, REDUCE>::update(
144
                &vals[k], src_data[offset + e * K + k], &args[k], e);
rusty1s's avatar
rusty1s committed
145
          }
146
147
        }
        for (int k = 0; k < K; k++) {
rusty1s's avatar
rusty1s committed
148
          Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
149
                                           arg_out_data + n * K + k, args[k],
rusty1s's avatar
rusty1s committed
150
151
152
153
154
155
156
                                           row_end - row_start);
        }
      }
    });
  });

  return std::make_tuple(out, arg_out);
rusty1s's avatar
rusty1s committed
157
158
}

rusty1s's avatar
rusty1s committed
159
160
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo(torch::Tensor src, torch::Tensor index, torch::Tensor out,
rusty1s's avatar
rusty1s committed
161
162
163
164
            std::string reduce) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  CHECK_CPU(out);
rusty1s's avatar
rusty1s committed
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182

  AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");

  // Broadcasting `index` via `expand`.
  auto sizes = index.sizes().vec();
  for (int i = 0; i < index.dim(); i++) {
    sizes[i] = src.size(i);
  }
  index = index.expand(sizes);

  src = src.contiguous();
  out = out.contiguous();
  auto reduce_dim = index.dim() - 1;

  for (int i = 0; i < out.dim(); i++)
    if (i != reduce_dim)
      AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");

rusty1s's avatar
rusty1s committed
183
  torch::optional<torch::Tensor> arg_out = torch::nullopt;
rusty1s's avatar
rusty1s committed
184
  int64_t *arg_out_data = nullptr;
rusty1s's avatar
rusty1s committed
185
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
rusty1s's avatar
rusty1s committed
186
    arg_out = torch::full_like(out, src.size(reduce_dim), index.options());
rusty1s's avatar
rusty1s committed
187
188
189
190
191
192
193
194
195
196
197
198
199
200
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  auto E_1 = index.numel() / src.size(reduce_dim);
  auto E_2 = src.size(reduce_dim);
  auto K = src.numel() / index.numel();
  auto N = out.size(reduce_dim);

  auto index_info = getTensorInfo<int64_t>(index);
  auto stride = index_info.strides[index_info.dims - 1];
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

201
202
203
    std::vector<scalar_t> vals(K);
    int64_t idx, next_idx, row_start;
    std::vector<int64_t> args(K);
rusty1s's avatar
rusty1s committed
204
205
206
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      for (int e_1 = 0; e_1 < E_1; e_1++) {
        int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
207
        idx = index_info.data[offset];
rusty1s's avatar
rusty1s committed
208
209

        for (int k = 0; k < K; k++) {
210
211
          vals[k] = out_data[e_1 * N * K + k];
        }
rusty1s's avatar
rusty1s committed
212

rusty1s's avatar
rusty1s committed
213
        row_start = 0;
214
215
216
        for (int e_2 = 0; e_2 < E_2; e_2++) {

          for (int k = 0; k < K; k++) {
rusty1s's avatar
rusty1s committed
217
            Reducer<scalar_t, REDUCE>::update(
218
219
                &vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2);
          }
rusty1s's avatar
rusty1s committed
220

221
222
          if (e_2 == E_2 - 1) {
            for (int k = 0; k < K; k++) {
rusty1s's avatar
rusty1s committed
223
              Reducer<scalar_t, REDUCE>::write(
224
225
                  out_data + e_1 * N * K + idx * K + k, vals[k],
                  arg_out_data + e_1 * N * K + idx * K + k, args[k],
rusty1s's avatar
rusty1s committed
226
                  e_2 + 1 - row_start);
227
228
229
            }
          } else {
            next_idx = index_info.data[offset + (e_2 + 1) * stride];
rusty1s's avatar
rusty1s committed
230
            assert(idx <= next_idx);
rusty1s's avatar
rusty1s committed
231

232
233
            if (idx != next_idx) {
              for (int k = 0; k < K; k++) {
rusty1s's avatar
rusty1s committed
234
                Reducer<scalar_t, REDUCE>::write(
235
236
                    out_data + e_1 * N * K + idx * K + k, vals[k],
                    arg_out_data + e_1 * N * K + idx * K + k, args[k],
rusty1s's avatar
rusty1s committed
237
238
                    e_2 + 1 - row_start);

239
                vals[k] = out_data[e_1 * N * K + next_idx * K + k];
rusty1s's avatar
rusty1s committed
240
              }
241
              row_start = e_2 + 1;
rusty1s's avatar
rusty1s committed
242
            }
243
244

            idx = next_idx;
rusty1s's avatar
rusty1s committed
245
246
247
248
249
250
251
          }
        }
      }
    });
  });

  return std::make_tuple(out, arg_out);
rusty1s's avatar
rusty1s committed
252
253
}

rusty1s's avatar
rusty1s committed
254
255
256
static auto registry =
    torch::RegisterOperators("torch_scatter_cpu::segment_csr", &segment_csr)
        .op("torch_scatter_cpu::segment_coo", &segment_coo);