Commit 3994f3ab authored by rusty1s

faster segment coo cpu implementation

parent 4a5379c4
@@ -180,6 +180,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
 
+  auto E = index.numel();
   auto E_1 = index.numel() / src.size(reduce_dim);
   auto E_2 = src.size(reduce_dim);
   auto K = src.numel() / index.numel();
@@ -191,41 +192,48 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
     auto src_data = src.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();
 
-    scalar_t val;
-    int64_t idx, next_idx, row_start, arg;
+    scalar_t vals[K];
+    int64_t idx, next_idx, row_start, args[K];
     AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
       for (int e_1 = 0; e_1 < E_1; e_1++) {
         int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
+        idx = index_info.data[offset];
+        row_start = 0;
         for (int k = 0; k < K; k++) {
-          idx = index_info.data[offset];
-          row_start = 0;
-          val = out_data[e_1 * N * K + k];
+          vals[k] = out_data[e_1 * N * K + k];
+        }
 
         for (int e_2 = 0; e_2 < E_2; e_2++) {
+          for (int k = 0; k < K; k++) {
            Reducer<scalar_t, REDUCE>::update(
-               &val, src_data[e_1 * E_2 * K + e_2 * K + k], &arg, e_2);
+               &vals[k], src_data[e_1 * E_2 * K + e_2 * K + k], &args[k], e_2);
+          }
 
          if (e_2 == E_2 - 1) {
+            for (int k = 0; k < K; k++) {
              Reducer<scalar_t, REDUCE>::write(
-                 out_data + e_1 * N * K + idx * K + k, val,
-                 arg_out_data + e_1 * N * K + idx * K + k, arg,
+                 out_data + e_1 * N * K + idx * K + k, vals[k],
+                 arg_out_data + e_1 * N * K + idx * K + k, args[k],
                  e_2 + 1 - row_start);
+            }
          } else {
            next_idx = index_info.data[offset + (e_2 + 1) * stride];
            if (idx != next_idx) {
+              for (int k = 0; k < K; k++) {
                Reducer<scalar_t, REDUCE>::write(
-                   out_data + e_1 * N * K + idx * K + k, val,
-                   arg_out_data + e_1 * N * K + idx * K + k, arg,
+                   out_data + e_1 * N * K + idx * K + k, vals[k],
+                   arg_out_data + e_1 * N * K + idx * K + k, args[k],
                    e_2 + 1 - row_start);
-              row_start = e_2 + 1;
-              val = out_data[e_1 * N * K + next_idx * K + k];
+                vals[k] = out_data[e_1 * N * K + next_idx * K + k];
              }
+              row_start = e_2 + 1;
+              idx = next_idx;
            }
-          idx = next_idx;
          }
        }
      }
...
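
The restructuring above keeps the reduction semantics but swaps the loop order: instead of scanning the reduce dimension once per feature channel k with a single running val/arg, the new kernel reads each index entry once and accumulates all K channels of that element together in the vals[K]/args[K] buffers, so src and index are traversed sequentially. A rough standalone sketch of the same loop order for a plain sum reduction follows; the simplified layout and the names are assumptions for illustration, not the extension's actual kernel or API:

// Illustrative only: sum-reduce src of shape [E_2, K] into out of shape [N, K]
// along a sorted COO index of length E_2, accumulating all K channels per
// index entry, mirroring the loop order introduced by this commit.
#include <cstdint>
#include <vector>

void segment_sum_coo_sketch(const float *src, const int64_t *index,
                            int64_t E_2, int64_t K, float *out) {
  if (E_2 == 0)
    return;
  std::vector<float> vals(K, 0.0f); // running sums for all K channels of the current segment
  int64_t idx = index[0];           // output row of the current segment, read once per element
  for (int64_t e_2 = 0; e_2 < E_2; e_2++) {
    for (int64_t k = 0; k < K; k++) // fold every channel of element e_2 into the running sums
      vals[k] += src[e_2 * K + k];
    int64_t next_idx = (e_2 == E_2 - 1) ? idx : index[e_2 + 1];
    if (e_2 == E_2 - 1 || idx != next_idx) {
      // Segment ended: flush all K channels to the output row and reset.
      for (int64_t k = 0; k < K; k++) {
        out[idx * K + k] += vals[k];
        vals[k] = 0.0f;
      }
      idx = next_idx;
    }
  }
}

For min/max reductions the real kernel additionally tracks the winning position per channel, which is why arg is promoted to a per-channel args array alongside vals in the diff.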