Commit 54d23137 authored by rusty1s
Browse files

use last index for dim_size in segment_coo

parent 524326d0
...@@ -36,8 +36,11 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, ...@@ -36,8 +36,11 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
sizes[dim] = dim_size.value(); sizes[dim] = dim_size.value();
else if (index.numel() == 0) else if (index.numel() == 0)
sizes[dim] = 0; sizes[dim] = 0;
else else {
sizes[dim] = 1 + *index.max().data_ptr<int64_t>(); auto tmp = index.select(dim, index.size(dim) - 1);
tmp = tmp.numel() > 1 ? tmp.max() : tmp;
sizes[dim] = 1 + *tmp.data_ptr<int64_t>();
}
out = torch::empty(sizes, src.options()); out = torch::empty(sizes, src.options());
} }
......
...@@ -184,7 +184,9 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, ...@@ -184,7 +184,9 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
else if (index.numel() == 0) else if (index.numel() == 0)
sizes[dim] = 0; sizes[dim] = 0;
else { else {
auto d_size = index.max().data_ptr<int64_t>(); auto tmp = index.select(dim, index.size(dim) - 1);
tmp = tmp.numel() > 1 ? tmp.max() : tmp;
auto d_size = tmp.data_ptr<int64_t>();
auto h_size = (int64_t *)malloc(sizeof(int64_t)); auto h_size = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost); cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
sizes[dim] = 1 + *h_size; sizes[dim] = 1 + *h_size;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment