Commit 105a60be authored by rusty1s's avatar rusty1s
Browse files

fix spmm for highly sparse matrices

parent fca68194
......@@ -56,8 +56,8 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
value_data = optional_value.value().data_ptr<scalar_t>();
}
int64_t grain_size =
at::internal::GRAIN_SIZE / (K * (col.numel() / M));
int64_t grain_size = at::internal::GRAIN_SIZE /
(K * std::max(col.numel() / M, (int64_t)1));
at::parallel_for(
0, B * M, grain_size, [&](int64_t begin, int64_t end) {
scalar_t val;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment