Commit 54d23137 authored by rusty1s's avatar rusty1s
Browse files

use last index for dim_size in segment_coo

parent 524326d0
......@@ -36,8 +36,11 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
sizes[dim] = dim_size.value();
else if (index.numel() == 0)
sizes[dim] = 0;
else
sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
else {
auto tmp = index.select(dim, index.size(dim) - 1);
tmp = tmp.numel() > 1 ? tmp.max() : tmp;
sizes[dim] = 1 + *tmp.data_ptr<int64_t>();
}
out = torch::empty(sizes, src.options());
}
......
......@@ -184,7 +184,9 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
else if (index.numel() == 0)
sizes[dim] = 0;
else {
auto d_size = index.max().data_ptr<int64_t>();
auto tmp = index.select(dim, index.size(dim) - 1);
tmp = tmp.numel() > 1 ? tmp.max() : tmp;
auto d_size = tmp.data_ptr<int64_t>();
auto h_size = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
sizes[dim] = 1 + *h_size;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment