Commit 9d4cf96d authored by rusty1s's avatar rusty1s
Browse files

fix logsumexp with out tensor

parent 778f6245
......@@ -32,7 +32,7 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf'))
if out is not None:
out = out.sub(max_per_src_element).exp()
out = out.sub_(max_value_per_index).exp_()
sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out,
dim_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment