// segment.cpp
#include <torch/extension.h>

#include "compat.h"
#include "index_info.h"

#include <cassert>
#include <limits>
#include <string>
#include <tuple>
#include <vector>

// Asserts that tensor expression `x` is not placed on a CUDA device. On
// failure the message starts with the stringified expression, e.g.
// "src must be CPU tensor". The device is queried through `device()` rather
// than the deprecated `type()` accessor, and `x` is parenthesised so that
// arbitrary expressions can be passed safely.
#define CHECK_CPU(x)                                                           \
  AT_ASSERTM(!(x).device().is_cuda(), #x " must be CPU tensor")

// Reduction operations understood by the segment kernels in this file.
enum ReductionType {
  ADD = 0,  // sum of the segment
  MEAN = 1, // sum divided by the segment length
  MIN = 2,  // smallest element (also tracks its position)
  MAX = 3,  // largest element (also tracks its position)
};

// Translates the runtime string `reduce` ("add", "mean", "min" or "max") into
// the matching ReductionType, stores it in a local variable named `REDUCE`,
// and then invokes the callable given as the trailing macro argument. The
// callable is expanded inside the scope that declares `REDUCE`, so a lambda
// with a by-reference capture can read it. The macro evaluates to whatever the
// callable returns.
//
// Any other value of `reduce` triggers an assertion failure instead of
// silently skipping the callable (which would leave outputs unwritten).
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    ReductionType REDUCE = ADD;                                                \
    if ((reduce) == "add") {                                                   \
      REDUCE = ADD;                                                            \
    } else if ((reduce) == "mean") {                                           \
      REDUCE = MEAN;                                                           \
    } else if ((reduce) == "min") {                                            \
      REDUCE = MIN;                                                            \
    } else if ((reduce) == "max") {                                            \
      REDUCE = MAX;                                                            \
    } else {                                                                   \
      AT_ASSERTM(false, "Unsupported reduction type (expected 'add', "        \
                        "'mean', 'min' or 'max')");                            \
    }                                                                          \
    return __VA_ARGS__();                                                      \
  }()

template <typename scalar_t> struct Reducer {
  static inline scalar_t init(ReductionType REDUCE) {
rusty1s's avatar
rusty1s committed
32
33
34
35
36
37
38
39
40
    if (REDUCE == MIN) {
      return std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      return std::numeric_limits<scalar_t>::lowest();
    } else {
      return (scalar_t)0;
    }
  }

41
  static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val) {
rusty1s's avatar
rusty1s committed
42
43
44
45
46
47
48
49
    if (REDUCE == ADD || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
    }
  }

50
  static inline void update(ReductionType REDUCE, scalar_t *val, scalar_t new_val, int64_t *arg,
rusty1s's avatar
rusty1s committed
51
52
53
54
55
56
57
58
59
60
                            int64_t new_arg) {
    if (REDUCE == ADD || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
      *arg = new_arg;
    }
  }

61
  static inline void write(ReductionType REDUCE, scalar_t *address, scalar_t val,
rusty1s's avatar
rusty1s committed
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
                           int64_t *arg_address, int64_t arg, int count) {
    if (REDUCE == ADD) {
      *address = val;
    } else if (REDUCE == MEAN) {
      *address = val / (count > 0 ? count : (scalar_t)1);
    } else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else {
        *address = (scalar_t)0;
      }
    }
  }
};

// Reduces consecutive slices of `src` along dimension `indptr.dim() - 1`.
//
// Along its last dimension `indptr` holds segment boundaries: output position
// i is the reduction of source positions [indptr[i], indptr[i + 1]). Leading
// dimensions of `indptr` are expanded to the matching sizes of `src`.
//
// Arguments:
//   src     - CPU tensor with at least as many dimensions as `indptr`.
//   indptr  - CPU tensor of boundaries, read as int64.
//   out_opt - optional CPU output tensor; it must match `src` in every
//             dimension except the reduced one, which must have
//             `indptr.size(-1) - 1` entries. When absent, a new tensor of that
//             shape is allocated.
//   reduce  - "add", "mean", "min" or "max".
//
// Returns (out, arg_out). `arg_out` is only set for "min"/"max"; it holds the
// position (along the reduced dimension of `src`) of the selected element and
// keeps the fill value `src.size(reduce_dim)` where no element was selected.
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
            std::string reduce) {
  CHECK_CPU(src);
  CHECK_CPU(indptr);
  if (out_opt.has_value()) {
    CHECK_CPU(out_opt.value());
  }

  AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");

  // Broadcasting `indptr` via `expand`.
  auto sizes = indptr.sizes().vec();
  for (int64_t i = 0; i < indptr.dim() - 1; i++) {
    sizes[i] = src.size(i);
  }
  indptr = indptr.expand(sizes);

  src = src.contiguous();
  const int64_t reduce_dim = indptr.dim() - 1;

  at::Tensor out;
  if (out_opt.has_value()) {
    out = out_opt.value().contiguous();
    for (int64_t i = 0; i < out.dim(); i++) {
      if (i != reduce_dim) {
        AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
      }
    }
    AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1,
               "Input mismatch");
  } else {
    sizes = src.sizes().vec();
    sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
    out = at::empty(sizes, src.options());
  }

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  // Nothing to compute for an empty output. Returning here also keeps the
  // divisions below (by N and by the number of segments) away from zero.
  if (out.numel() == 0) {
    return std::make_tuple(out, arg_out);
  }

  const int64_t segments_per_row = indptr.size(-1) - 1;
  const int64_t N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
  const int64_t K = out.numel() / N;
  const int64_t E = src.size(reduce_dim);

  auto indptr_info = getTensorInfo<int64_t>(indptr);
  auto stride = indptr_info.strides[indptr_info.dims - 1];
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    std::vector<scalar_t> vals(K);
    std::vector<int64_t> args(K);
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      // All counters are 64-bit: tensor sizes and offsets are int64_t.
      for (int64_t n = 0; n < N; n++) {
        const int64_t ptr_offset =
            IndexPtrToOffset<int64_t>::get(n, indptr_info);
        const int64_t row_start = indptr_info.data[ptr_offset];
        const int64_t row_end = indptr_info.data[ptr_offset + stride];

        // Start of the `src` slab that output row `n` reads from.
        const int64_t src_offset = (n / segments_per_row) * E * K;

        for (int64_t k = 0; k < K; k++) {
          vals[k] = Reducer<scalar_t>::init(REDUCE);
          // Reset to the same "nothing selected" value `arg_out` is filled
          // with, so a position from a previous segment is never reused.
          args[k] = E;
        }
        for (int64_t e = row_start; e < row_end; e++) {
          for (int64_t k = 0; k < K; k++) {
            Reducer<scalar_t>::update(REDUCE, &vals[k],
                                      src_data[src_offset + e * K + k],
                                      &args[k], e);
          }
        }
        for (int64_t k = 0; k < K; k++) {
          const int64_t slot = n * K + k;
          // Avoid offsetting a null pointer when no arg output exists.
          int64_t *arg_slot =
              arg_out_data == nullptr ? nullptr : arg_out_data + slot;
          Reducer<scalar_t>::write(REDUCE, out_data + slot, vals[k], arg_slot,
                                   args[k], row_end - row_start);
        }
      }
    });
  });

  return std::make_tuple(out, arg_out);
}

// Reduces `src` into `out` along dimension `index.dim() - 1`, grouping
// consecutive source positions that share the same value in `index`.
//
// `index` is expanded to the leading sizes of `src`. For every run of equal
// index values, the run is folded into an accumulator that starts from the
// value currently stored in `out` at that index, and the result is written
// back to that slot. Runs are detected by comparing neighbouring entries, so
// `index` is expected to be non-decreasing along its last dimension; this is
// only verified by an `assert`. Index values are not range-checked against
// `out.size(reduce_dim)`.
//
// Arguments:
//   src    - CPU tensor with at least as many dimensions as `index`.
//   index  - CPU tensor of target positions, read as int64.
//   out    - CPU tensor that matches `src` in every dimension except the
//            reduced one; supplies the initial values and receives results.
//   reduce - "add", "mean", "min" or "max".
//
// Returns (out, arg_out). `arg_out` is only set for "min"/"max"; it holds the
// source position of the selected element and keeps the fill value
// `src.size(reduce_dim)` where no source element was selected.
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
            std::string reduce) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  CHECK_CPU(out);

  AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");

  // Broadcasting `index` via `expand`.
  auto sizes = index.sizes().vec();
  for (int64_t i = 0; i < index.dim(); i++) {
    sizes[i] = src.size(i);
  }
  index = index.expand(sizes);

  src = src.contiguous();
  out = out.contiguous();
  const int64_t reduce_dim = index.dim() - 1;

  for (int64_t i = 0; i < out.dim(); i++) {
    if (i != reduce_dim) {
      AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
    }
  }

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  // With no source elements there is nothing to fold into `out`. Returning
  // here also keeps the divisions below away from zero.
  if (src.numel() == 0) {
    return std::make_tuple(out, arg_out);
  }

  const int64_t E_2 = src.size(reduce_dim);
  const int64_t E_1 = index.numel() / E_2;
  const int64_t K = src.numel() / index.numel();
  const int64_t N = out.size(reduce_dim);

  auto index_info = getTensorInfo<int64_t>(index);
  auto stride = index_info.strides[index_info.dims - 1];
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    std::vector<scalar_t> vals(K);
    std::vector<int64_t> args(K);
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      // All counters are 64-bit: tensor sizes and offsets are int64_t.
      for (int64_t e_1 = 0; e_1 < E_1; e_1++) {
        const int64_t index_offset =
            IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
        const int64_t out_offset = e_1 * N * K;
        const int64_t src_offset = e_1 * E_2 * K;

        int64_t idx = index_info.data[index_offset];
        int64_t row_start = 0;

        // Seed the accumulator from the slot of the first run (`idx`), the
        // same way later runs are seeded from `next_idx`.
        for (int64_t k = 0; k < K; k++) {
          vals[k] = out_data[out_offset + idx * K + k];
          args[k] = E_2; // "nothing selected" value, matches arg_out's fill
        }

        for (int64_t e_2 = 0; e_2 < E_2; e_2++) {
          for (int64_t k = 0; k < K; k++) {
            Reducer<scalar_t>::update(REDUCE, &vals[k],
                                      src_data[src_offset + e_2 * K + k],
                                      &args[k], e_2);
          }

          const bool is_last = (e_2 == E_2 - 1);
          int64_t next_idx = idx;
          if (!is_last) {
            next_idx = index_info.data[index_offset + (e_2 + 1) * stride];
            assert(idx <= next_idx);
          }

          // A run ends at the last element or where the index value changes.
          if (is_last || idx != next_idx) {
            for (int64_t k = 0; k < K; k++) {
              const int64_t slot = out_offset + idx * K + k;
              // Avoid offsetting a null pointer when no arg output exists.
              int64_t *arg_slot =
                  arg_out_data == nullptr ? nullptr : arg_out_data + slot;
              Reducer<scalar_t>::write(REDUCE, out_data + slot, vals[k],
                                       arg_slot, args[k],
                                       e_2 + 1 - row_start);

              if (!is_last) {
                // Re-seed the accumulator for the run that starts next.
                vals[k] = out_data[out_offset + next_idx * K + k];
                args[k] = E_2;
              }
            }
            row_start = e_2 + 1;
          }

          idx = next_idx;
        }
      }
    });
  });

  return std::make_tuple(out, arg_out);
}

// Python bindings: registers both CPU segment reductions on the extension
// module under their C++ names.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  mod.def("segment_csr", &segment_csr, "Segment CSR (CPU)");
  mod.def("segment_coo", &segment_coo, "Segment COO (CPU)");
}