// segment_csr_cpu.cpp
#include "segment_csr_cpu.h"

#include "index_info.h"
#include "reducer.h"
#include "utils.h"

// Reduces consecutive segments of `src` along dimension `indptr.dim() - 1`.
//
// Segment `j` covers the half-open index range [indptr[j], indptr[j + 1])
// along that dimension; the leading dimensions of `indptr` are expanded to
// the matching leading sizes of `src`.
//
// Parameters:
//   src          - values to reduce; made contiguous before use.
//   indptr       - segment boundaries, read as int64_t.
//   optional_out - optional pre-allocated output. Every dimension except the
//                  reduced one must match `src`, and the reduced dimension
//                  must have `indptr.size(dim) - 1` entries.
//   reduce       - key looked up in `reduce2REDUCE` (defined elsewhere).
//
// Returns (out, arg_out). `arg_out` is only populated when the reduction is
// MIN or MAX; it is pre-filled with `src.size(dim)` and then written by
// `Reducer::write` (defined elsewhere).
//
// An empty `src` or an empty output is returned without running the
// reduction loop.
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
                torch::optional<torch::Tensor> optional_out,
                std::string reduce) {
  CHECK_CPU(src);
  CHECK_CPU(indptr);
  if (optional_out.has_value())
    CHECK_CPU(optional_out.value());

  CHECK_INPUT(src.dim() >= indptr.dim());

  // Expand the leading dimensions of `indptr` to the sizes of `src`.
  auto sizes = indptr.sizes().vec();
  for (int64_t i = 0; i < indptr.dim() - 1; i++)
    sizes[i] = src.size(i);
  indptr = indptr.expand(sizes);

  auto dim = indptr.dim() - 1;

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (int64_t i = 0; i < out.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
    CHECK_INPUT(src.numel() == 0 || out.size(dim) == indptr.size(dim) - 1);
  } else {
    sizes = src.sizes().vec();
    // An empty `indptr` yields zero segments rather than a negative size.
    sizes[dim] = std::max<int64_t>(indptr.size(dim) - 1, 0);
    out = torch::empty(sizes, src.options());
  }

  torch::optional<torch::Tensor> arg_out = torch::nullopt;
  int64_t *arg_out_data = nullptr;
  const auto reduction = reduce2REDUCE.at(reduce);
  if (reduction == MIN || reduction == MAX) {
    arg_out = torch::full(out.sizes(), src.size(dim), indptr.options());
    arg_out_data = arg_out.value().data_ptr<int64_t>();
  }

  // Nothing to reduce. The `out.numel() == 0` case must also stop here:
  // otherwise `N` below is zero (or itself divides by zero when `indptr` is
  // empty) and computing `K` performs an integer division by zero.
  if (src.numel() == 0 || out.numel() == 0)
    return std::make_tuple(out, arg_out);

  // N: total number of segments across all leading (batch) rows of `indptr`.
  // K: number of trailing elements reduced independently per segment.
  // E: extent of `src` along the reduced dimension.
  auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(dim);
  const int64_t segments_per_row = indptr.size(-1) - 1;

  auto indptr_info = getTensorInfo<int64_t>(indptr);
  auto stride = indptr_info.strides[indptr_info.dims - 1];
  std::vector<int64_t> args(K);
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    std::vector<scalar_t> vals(K);
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      for (int64_t n = 0; n < N; n++) {
        auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
        const int64_t row_start = indptr_info.data[offset];
        const int64_t row_end = indptr_info.data[offset + stride];
        // A non-empty segment must lie inside [0, E); otherwise the reads
        // from `src_data` below would run outside the tensor.
        CHECK_INPUT(row_start >= row_end || (row_start >= 0 && row_end <= E));

        // Start of the batch row of `src` that segment `n` belongs to.
        const int64_t src_offset = (n / segments_per_row) * E * K;
        for (int64_t k = 0; k < K; k++)
          vals[k] = Reducer<scalar_t>::init(REDUCE);

        for (int64_t e = row_start; e < row_end; e++)
          for (int64_t k = 0; k < K; k++)
            Reducer<scalar_t>::update(REDUCE, &vals[k],
                                      src_data[src_offset + e * K + k],
                                      &args[k], e);

        for (int64_t k = 0; k < K; k++)
          Reducer<scalar_t>::write(REDUCE, out_data + n * K + k, vals[k],
                                   arg_out_data + n * K + k, args[k],
                                   row_end - row_start);
      }
    });
  });

  return std::make_tuple(out, arg_out);
}

// Expands `src` along dimension `indptr.dim() - 1`: entry `j` of `src` is
// copied to every output index in the half-open range
// [indptr[j], indptr[j + 1]) along that dimension.
//
// The leading dimensions of `indptr` are expanded to the matching leading
// sizes of `src`.
//
// Parameters:
//   src          - values to replicate; made contiguous before use. Along the
//                  gathered dimension it must be empty or hold
//                  `indptr.size(dim) - 1` entries.
//   indptr       - segment boundaries, read as int64_t.
//   optional_out - optional pre-allocated output. Every dimension except the
//                  gathered one must match `src`. When absent, the gathered
//                  dimension is sized from the last element of the flattened
//                  `indptr` (or 0 for an empty `src`) and allocated with
//                  `torch::empty`.
//
// Returns the output tensor. Only positions covered by a segment are
// written; an empty `src` returns the output untouched.
torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
                             torch::optional<torch::Tensor> optional_out) {
  CHECK_CPU(src);
  CHECK_CPU(indptr);
  if (optional_out.has_value())
    CHECK_CPU(optional_out.value());

  CHECK_INPUT(src.dim() >= indptr.dim());

  // Expand the leading dimensions of `indptr` to the sizes of `src`.
  auto sizes = indptr.sizes().vec();
  for (int64_t i = 0; i < indptr.dim() - 1; i++)
    sizes[i] = src.size(i);
  indptr = indptr.expand(sizes);

  auto dim = indptr.dim() - 1;
  CHECK_INPUT(src.size(dim) == 0 || src.size(dim) == indptr.size(dim) - 1);

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (int64_t i = 0; i < out.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
  } else {
    sizes = src.sizes().vec();
    if (src.numel() > 0)
      sizes[dim] = *indptr.flatten()[-1].data_ptr<int64_t>();
    else
      sizes[dim] = 0;
    out = torch::empty(sizes, src.options());
  }

  if (src.numel() == 0)
    return out;

  // N: total number of segments across all leading (batch) rows of `indptr`.
  // K: number of trailing elements copied together per segment.
  // E: extent of `out` along the gathered dimension.
  auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
  auto K = src.numel() / N;
  auto E = out.size(dim);
  const int64_t segments_per_row = indptr.size(-1) - 1;

  auto indptr_info = getTensorInfo<int64_t>(indptr);
  auto stride = indptr_info.strides[indptr_info.dims - 1];
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    std::vector<scalar_t> vals(K);
    for (int64_t n = 0; n < N; n++) {
      auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
      const int64_t row_start = indptr_info.data[offset];
      const int64_t row_end = indptr_info.data[offset + stride];
      // A non-empty segment must lie inside [0, E); otherwise the writes to
      // `out_data` below would land outside the tensor (e.g. when a
      // caller-supplied `out` is shorter than `indptr` requires).
      CHECK_INPUT(row_start >= row_end || (row_start >= 0 && row_end <= E));

      for (int64_t k = 0; k < K; k++)
        vals[k] = src_data[n * K + k];

      // Start of the batch row of `out` that segment `n` belongs to.
      const int64_t out_offset = (n / segments_per_row) * E * K;
      for (int64_t e = row_start; e < row_end; e++)
        for (int64_t k = 0; k < K; k++)
          out_data[out_offset + e * K + k] = vals[k];
    }
  });

  return out;
}