Unverified Commit 49f7b401 authored by silent567's avatar silent567 Committed by GitHub
Browse files

Update logsumexp.py

Debugging for the corner cases involving -inf.
parent 943cfd59
...@@ -25,9 +25,7 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -25,9 +25,7 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
size = src.size() size = src.size()
size[dim] = dim_size size[dim] = dim_size
max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype, max_value_per_index = scatter_max(src, index, dim=dim, dim_size=dim_size)[0]
device=src.device)
scatter_max(src, index, dim, max_value_per_index, dim_size)[0]
max_per_src_element = max_value_per_index.gather(dim, index) max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element recentered_scores = src - max_per_src_element
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment