"vscode:/vscode.git/clone" did not exist on "4f8a93975f33ed3ac7e9c7f4f8d8d3b1622a0178"
Commit 19a94a91 authored by rusty1s's avatar rusty1s
Browse files

fix assertion 'index out of bounds' in case dim_size is omitted

parent 5c9462e5
...@@ -84,10 +84,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, ...@@ -84,10 +84,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
else if (index.numel() == 0) else if (index.numel() == 0)
sizes[dim] = 0; sizes[dim] = 0;
else { else {
auto d_size = index.max().data_ptr<int64_t>(); sizes[dim] = 1 + index.max().cpu().data_ptr<int64_t>()[0];
auto h_size = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
sizes[dim] = 1 + *h_size;
} }
out = torch::empty(sizes, src.options()); out = torch::empty(sizes, src.options());
} }
......
...@@ -186,10 +186,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, ...@@ -186,10 +186,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
else { else {
auto tmp = index.select(dim, index.size(dim) - 1); auto tmp = index.select(dim, index.size(dim) - 1);
tmp = tmp.numel() > 1 ? tmp.max() : tmp; tmp = tmp.numel() > 1 ? tmp.max() : tmp;
auto d_size = tmp.data_ptr<int64_t>(); sizes[dim] = 1 + tmp.cpu().data_ptr<int64_t>()[0];
auto h_size = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
sizes[dim] = 1 + *h_size;
} }
out = torch::zeros(sizes, src.options()); out = torch::zeros(sizes, src.options());
} }
......
...@@ -245,12 +245,10 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, ...@@ -245,12 +245,10 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
} else { } else {
auto sizes = src.sizes().vec(); auto sizes = src.sizes().vec();
if (src.numel() > 0) { if (src.numel() > 0) {
auto d_size = indptr.flatten()[-1].data_ptr<int64_t>(); sizes[dim] = indptr.flatten()[-1].cpu().data_ptr<int64_t>()[0];
auto h_size = (int64_t *)malloc(sizeof(int64_t)); } else {
cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
sizes[dim] = *h_size;
} else
sizes[dim] = 0; sizes[dim] = 0;
}
out = torch::empty(sizes, src.options()); out = torch::empty(sizes, src.options());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment