Unverified Commit 66105d4b authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #132 from silent567/master

Debugging problematic behaviors of scatter_logsumexp for -inf
parents 943cfd59 f20d8742
......@@ -3,19 +3,18 @@ from torch_scatter import scatter_logsumexp
def test_logsumexp():
src = torch.tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, -100])
src.requires_grad_()
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
inputs = torch.tensor([
0.5, 0.5, 0.0, -2.1, 3.2, 7.0, -1.0, -100.0,
float('-inf'),
float('-inf'), 0.0
])
inputs.requires_grad_()
index = torch.tensor([0, 0, 1, 1, 1, 2, 4, 4, 5, 6, 6])
splits = [2, 3, 1, 0, 2, 1, 2]
out = scatter_logsumexp(src, index)
outputs = scatter_logsumexp(inputs, index)
out0 = torch.logsumexp(torch.tensor([0.5, 0.5]), dim=-1)
out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2]), dim=-1)
out2 = torch.logsumexp(torch.tensor(7, dtype=torch.float), dim=-1)
out3 = torch.logsumexp(torch.tensor([], dtype=torch.float), dim=-1)
out4 = torch.tensor(-1, dtype=torch.float)
for src, out in zip(inputs.split(splits), outputs.unbind()):
assert out.tolist() == torch.logsumexp(src, dim=0).tolist()
expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
assert torch.allclose(out, expected)
out.backward(torch.randn_like(out))
outputs.backward(torch.randn_like(outputs))
......@@ -23,18 +23,19 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
if dim_size is None:
dim_size = int(index.max()) + 1
size = src.size()
size = list(src.size())
size[dim] = dim_size
max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
device=src.device)
scatter_max(src, index, dim, max_value_per_index, dim_size)[0]
scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)[0]
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
recentered_score = src - max_per_src_element
recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf'))
if out is not None:
out = out.sub_(max_per_src_element).exp_()
out = out.sub(max_per_src_element).exp()
sum_per_index = scatter_sum(recentered_scores.exp_(), index, dim, out,
sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out,
dim_size)
return sum_per_index.add_(eps).log_().add_(max_value_per_index)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment