Commit 19a94a91 authored by rusty1s

fix assertion 'index out of bounds' when dim_size is omitted

parent 5c9462e5
@@ -84,10 +84,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
     else if (index.numel() == 0)
       sizes[dim] = 0;
     else {
-      auto d_size = index.max().data_ptr<int64_t>();
-      auto h_size = (int64_t *)malloc(sizeof(int64_t));
-      cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
-      sizes[dim] = 1 + *h_size;
+      sizes[dim] = 1 + index.max().cpu().data_ptr<int64_t>()[0];
     }
     out = torch::empty(sizes, src.options());
   }
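The pattern being replaced here is worth spelling out: when `dim_size` is omitted, the output extent along `dim` is inferred from the largest index, and the new code reads that scalar back through a `.cpu()` copy instead of a manual `malloc`/`cudaMemcpy`. Below is a minimal, self-contained sketch of that size inference; the example tensors and the standalone `main` are illustrative assumptions, not part of the commit.

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  // Hypothetical index tensor; moved to the GPU when available so the
  // device-to-host readback path is actually exercised.
  auto index = torch::tensor({0, 2, 2, 5}, torch::kLong);
  if (torch::cuda::is_available())
    index = index.to(torch::kCUDA);

  auto src = torch::randn({index.size(0), 4},
                          torch::dtype(torch::kFloat).device(index.device()));
  int64_t dim = 0;

  auto sizes = src.sizes().vec();
  if (index.numel() == 0) {
    sizes[dim] = 0;
  } else {
    // Same idea as the patched line: reduce on the device, then copy the
    // 0-dim result to the host before dereferencing it.
    // index.max().item<int64_t>() would be an equivalent readback.
    sizes[dim] = 1 + index.max().cpu().data_ptr<int64_t>()[0];
  }

  auto out = torch::empty(sizes, src.options());
  std::cout << out.sizes() << std::endl; // [6, 4] for the values above
}
```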
@@ -186,10 +186,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
     else {
       auto tmp = index.select(dim, index.size(dim) - 1);
       tmp = tmp.numel() > 1 ? tmp.max() : tmp;
-      auto d_size = tmp.data_ptr<int64_t>();
-      auto h_size = (int64_t *)malloc(sizeof(int64_t));
-      cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
-      sizes[dim] = 1 + *h_size;
+      sizes[dim] = 1 + tmp.cpu().data_ptr<int64_t>()[0];
     }
     out = torch::zeros(sizes, src.options());
   }
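For `segment_coo`, the index is expected to be sorted along `dim`, so the last slice along that dimension already holds the largest segment id of every row; only that slice needs to be reduced and copied back. A minimal sketch under that assumption (the concrete index values and the standalone `main` are illustrative, not from the commit):

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  // Hypothetical COO index, sorted along the last dimension.
  auto index = torch::tensor({0, 0, 1, 3, 0, 2, 2, 4}, torch::kLong).view({2, 4});
  if (torch::cuda::is_available())
    index = index.to(torch::kCUDA);

  auto src = torch::randn({2, 4},
                          torch::dtype(torch::kFloat).device(index.device()));
  int64_t dim = 1;

  auto sizes = src.sizes().vec();
  // The last slice along `dim` contains each row's maximum segment id.
  auto tmp = index.select(dim, index.size(dim) - 1);
  tmp = tmp.numel() > 1 ? tmp.max() : tmp;
  // Copy the 0-dim result to the host before reading it, as in the patch.
  sizes[dim] = 1 + tmp.cpu().data_ptr<int64_t>()[0];

  auto out = torch::zeros(sizes, src.options());
  std::cout << out.sizes() << std::endl; // [2, 5] for the values above
}
```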
@@ -245,12 +245,10 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
   } else {
     auto sizes = src.sizes().vec();
     if (src.numel() > 0) {
-      auto d_size = indptr.flatten()[-1].data_ptr<int64_t>();
-      auto h_size = (int64_t *)malloc(sizeof(int64_t));
-      cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
-      sizes[dim] = *h_size;
-    } else
+      sizes[dim] = indptr.flatten()[-1].cpu().data_ptr<int64_t>()[0];
+    } else {
       sizes[dim] = 0;
+    }
     out = torch::empty(sizes, src.options());
   }
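For `gather_csr`, the output extent along `dim` is simply the last entry of the (flattened) CSR row pointer `indptr`, again read back on the host. A minimal sketch with assumed inputs (the `indptr` values and the standalone `main` are illustrative, not from the commit):

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  // Hypothetical CSR row pointer: three segments of sizes 2, 0 and 3.
  auto indptr = torch::tensor({0, 2, 2, 5}, torch::kLong);
  if (torch::cuda::is_available())
    indptr = indptr.to(torch::kCUDA);

  auto src = torch::randn({3, 4},
                          torch::dtype(torch::kFloat).device(indptr.device()));
  int64_t dim = 0;

  auto sizes = src.sizes().vec();
  if (src.numel() > 0) {
    // The last indptr entry is the total number of gathered rows; copy it to
    // the host before dereferencing, as the patched code does.
    sizes[dim] = indptr.flatten()[-1].cpu().data_ptr<int64_t>()[0];
  } else {
    sizes[dim] = 0;
  }

  auto out = torch::empty(sizes, src.options());
  std::cout << out.sizes() << std::endl; // [5, 4] for the values above
}
```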