segment_csr_cpu.cpp 4.92 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include "segment_csr_cpu.h"

#include "index_info.h"
#include "reducer.h"
#include "utils.h"

// Reduces `src` segment-wise along dimension `dim = indptr.dim() - 1`, where
// the segments are described in CSR (compressed) form by `indptr`.
//
// For every leading index and every segment j, the elements
// src[..., indptr[..., j] : indptr[..., j + 1], ...] are combined with the
// reducer selected by `reduce` and written to out[..., j, ...].
//
// Args:
//   src:          CPU tensor to reduce; must have at least `indptr.dim()` dims.
//   indptr:       CPU tensor of segment boundaries, read as int64_t. Its
//                 leading dims are expanded to match the leading dims of `src`.
//   optional_out: optional pre-allocated output. It must have the same rank
//                 as `src`, match `src` in every dim except `dim`, and (unless
//                 `src` is empty) have `indptr.size(dim) - 1` entries in `dim`.
//                 NOTE: `.contiguous()` copies a non-contiguous `out`, in which
//                 case results are only visible through the returned tensor.
//   reduce:       reduction name, looked up in `reduce2REDUCE` (throws
//                 std::out_of_range for unknown names).
// Returns:
//   (out, arg_out). `arg_out` is only allocated for MIN/MAX. It is filled with
//   `src.size(dim)` and then written by `Reducer::write` (presumably with the
//   position of the selected element along `dim` — confirm in reducer.h).
//   Otherwise it is nullopt.
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
                torch::optional<torch::Tensor> optional_out,
                std::string reduce) {
  CHECK_CPU(src);
  CHECK_CPU(indptr);
  if (optional_out.has_value())
    CHECK_CPU(optional_out.value());

  CHECK_INPUT(src.dim() >= indptr.dim());

  // Broadcast the leading (batch) dims of `indptr` against those of `src`.
  auto sizes = indptr.sizes().vec();
  for (int64_t i = 0; i < indptr.dim() - 1; i++)
    sizes[i] = src.size(i);
  indptr = indptr.expand(sizes);

  auto dim = indptr.dim() - 1;

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    // A rank mismatch would make the flat indexing below inconsistent.
    CHECK_INPUT(out.dim() == src.dim());
    for (int64_t i = 0; i < out.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
    CHECK_INPUT(src.numel() == 0 || out.size(dim) == indptr.size(dim) - 1);
  } else {
    sizes = src.sizes().vec();
    sizes[dim] = std::max<int64_t>(indptr.size(dim) - 1, 0);
    out = torch::empty(sizes, src.options());
  }

  torch::optional<torch::Tensor> arg_out = torch::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = torch::full(out.sizes(), src.size(dim), indptr.options());
    arg_out_data = arg_out.value().data_ptr<int64_t>();
  }

  if (src.numel() == 0) {
    if (!optional_out.has_value())
      out.fill_(0);
    return std::make_tuple(out, arg_out);
  }

  // No segments (indptr has fewer than two entries along `dim`): nothing to
  // write, and computing `K` below would divide by zero.
  if (out.numel() == 0)
    return std::make_tuple(out, arg_out);

  auto N = out.size(dim) * (indptr.numel() / indptr.size(-1)); // total segments
  auto K = out.numel() / N; // elements per trailing slice
  auto E = src.size(dim);   // length of the reduced dim in `src`
  // Non-zero here because out.size(dim) == indptr.size(dim) - 1 > 0.
  const auto segments_per_row = indptr.size(-1) - 1;

  auto indptr_info = getTensorInfo<int64_t>(indptr);
  auto stride = indptr_info.strides[indptr_info.dims - 1];
  std::vector<int64_t> args(K);
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    std::vector<scalar_t> vals(K);
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      for (int64_t n = 0; n < N; n++) {
        const int64_t ptr_offset =
            IndexPtrToOffset<int64_t>::get(n, indptr_info);
        const int64_t row_start = indptr_info.data[ptr_offset];
        const int64_t row_end = indptr_info.data[ptr_offset + stride];
        // Reject a malformed `indptr` instead of reading outside `src`.
        if (row_start < row_end) {
          CHECK_INPUT(row_start >= 0 && row_end <= E);
        }

        // Start of the `src` row this segment belongs to.
        const int64_t src_offset = (n / segments_per_row) * E * K;
        for (int64_t k = 0; k < K; k++)
          vals[k] = Reducer<scalar_t, REDUCE>::init();

        for (auto e = row_start; e < row_end; e++)
          for (int64_t k = 0; k < K; k++)
            Reducer<scalar_t, REDUCE>::update(
                &vals[k], src_data[src_offset + e * K + k], &args[k], e);

        for (int64_t k = 0; k < K; k++)
          // Avoid pointer arithmetic on nullptr when `arg_out` is not used.
          Reducer<scalar_t, REDUCE>::write(
              out_data + n * K + k, vals[k],
              arg_out_data ? arg_out_data + n * K + k : nullptr, args[k],
              row_end - row_start);
      }
    });
  });

  return std::make_tuple(out, arg_out);
}

// Inverse of the CSR segment reduction: broadcasts each slice of `src` along
// dimension `dim = indptr.dim() - 1` to every position of its segment.
//
// For every leading index and every segment j, src[..., j, ...] is copied to
// out[..., indptr[..., j] : indptr[..., j + 1], ...].
//
// Args:
//   src:          CPU tensor; `src.size(dim)` must be 0 or
//                 `indptr.size(dim) - 1`.
//   indptr:       CPU tensor of segment boundaries, read as int64_t. Its
//                 leading dims are expanded to match the leading dims of `src`.
//   optional_out: optional pre-allocated output with the same rank as `src`,
//                 matching `src` in every dim except `dim`. Every segment must
//                 lie within `[0, out.size(dim)]`.
//                 NOTE: `.contiguous()` copies a non-contiguous `out`, in which
//                 case results are only visible through the returned tensor.
// Returns:
//   The filled output. When allocated here, its size along `dim` is the last
//   value of the flattened `indptr`, and its entries are uninitialized unless
//   some segment covers them.
torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
                             torch::optional<torch::Tensor> optional_out) {
  CHECK_CPU(src);
  CHECK_CPU(indptr);
  if (optional_out.has_value())
    CHECK_CPU(optional_out.value());

  CHECK_INPUT(src.dim() >= indptr.dim());

  // Broadcast the leading (batch) dims of `indptr` against those of `src`.
  auto sizes = indptr.sizes().vec();
  for (int64_t i = 0; i < indptr.dim() - 1; i++)
    sizes[i] = src.size(i);
  indptr = indptr.expand(sizes);

  auto dim = indptr.dim() - 1;
  CHECK_INPUT(src.size(dim) == 0 || src.size(dim) == indptr.size(dim) - 1);

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    // A lower-rank `out` would be smaller than the flat writes below assume.
    CHECK_INPUT(out.dim() == src.dim());
    for (int64_t i = 0; i < out.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
  } else {
    sizes = src.sizes().vec();
    if (src.numel() > 0)
      sizes[dim] = *indptr.flatten()[-1].data_ptr<int64_t>();
    else
      sizes[dim] = 0;
    out = torch::empty(sizes, src.options());
  }

  if (src.numel() == 0) {
    if (!optional_out.has_value())
      out.fill_(0);
    return out;
  }

  auto N = src.size(dim) * (indptr.numel() / indptr.size(-1)); // total segments
  auto K = src.numel() / N; // elements per trailing slice
  auto E = out.size(dim);   // length of the scattered dim in `out`
  // Non-zero here because src.size(dim) == indptr.size(dim) - 1 > 0.
  const auto segments_per_row = indptr.size(-1) - 1;

  auto indptr_info = getTensorInfo<int64_t>(indptr);
  auto stride = indptr_info.strides[indptr_info.dims - 1];
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    std::vector<scalar_t> vals(K);
    for (int64_t n = 0; n < N; n++) {
      const int64_t ptr_offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
      const int64_t row_start = indptr_info.data[ptr_offset];
      const int64_t row_end = indptr_info.data[ptr_offset + stride];
      // Reject segments that would write outside `out` (e.g. a too-small
      // user-provided `out`, or ragged `indptr` rows).
      if (row_start < row_end) {
        CHECK_INPUT(row_start >= 0 && row_end <= E);
      }

      for (int64_t k = 0; k < K; k++)
        vals[k] = src_data[n * K + k];

      // Start of the `out` row this segment belongs to.
      const int64_t out_offset = (n / segments_per_row) * E * K;
      for (auto e = row_start; e < row_end; e++)
        for (int64_t k = 0; k < K; k++)
          out_data[out_offset + e * K + k] = vals[k];
    }
  });

  return out;
}