Commit 4c8e5298 authored by rusty1s
Browse files

fix set_diag for nnz=0

parent 6456fb4a
...@@ -54,6 +54,9 @@ torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col, ...@@ -54,6 +54,9 @@ torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool)); auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
auto mask_data = mask.data_ptr<bool>(); auto mask_data = mask.data_ptr<bool>();
if (E == 0)
return mask;
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
non_diag_mask_kernel<<<(E + THREADS - 1) / THREADS, THREADS, 0, stream>>>( non_diag_mask_kernel<<<(E + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
row_data, col_data, mask_data, N, k, num_diag, E); row_data, col_data, mask_data, N, k, num_diag, E);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment