Commit 5db00866 authored by rusty1s
Browse files

faster segment csr cpu implementation

parent 3994f3ab
...@@ -123,8 +123,8 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt, ...@@ -123,8 +123,8 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
auto src_data = src.DATA_PTR<scalar_t>(); auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>(); auto out_data = out.DATA_PTR<scalar_t>();
scalar_t val; scalar_t vals[K];
int64_t row_start, row_end, arg; int64_t row_start, row_end, args[K];
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info); int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
...@@ -133,13 +133,17 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt, ...@@ -133,13 +133,17 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
offset = (n / (indptr.size(-1) - 1)) * E * K; offset = (n / (indptr.size(-1) - 1)) * E * K;
for (int k = 0; k < K; k++) { for (int k = 0; k < K; k++) {
val = Reducer<scalar_t, REDUCE>::init(); vals[k] = Reducer<scalar_t, REDUCE>::init();
}
for (int64_t e = row_start; e < row_end; e++) { for (int64_t e = row_start; e < row_end; e++) {
for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::update( Reducer<scalar_t, REDUCE>::update(
&val, src_data[offset + e * K + k], &arg, e); &vals[k], src_data[offset + e * K + k], &args[k], e);
} }
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, val, }
arg_out_data + n * K + k, arg, for (int k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k],
row_end - row_start); row_end - row_start);
} }
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment