Commit 5db00866 authored by rusty1s's avatar rusty1s
Browse files

faster segment csr cpu implementation

parent 3994f3ab
......@@ -123,8 +123,8 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
scalar_t val;
int64_t row_start, row_end, arg;
scalar_t vals[K];
int64_t row_start, row_end, args[K];
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (int n = 0; n < N; n++) {
int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
......@@ -133,13 +133,17 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (int k = 0; k < K; k++) {
val = Reducer<scalar_t, REDUCE>::init();
vals[k] = Reducer<scalar_t, REDUCE>::init();
}
for (int64_t e = row_start; e < row_end; e++) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::update(
&val, src_data[offset + e * K + k], &arg, e);
&vals[k], src_data[offset + e * K + k], &args[k], e);
}
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, val,
arg_out_data + n * K + k, arg,
}
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k],
row_end - row_start);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment