Commit f20d8742 authored by rusty1s's avatar rusty1s
Browse files

logsumexp cleanup

parent 49f7b401
...@@ -3,19 +3,18 @@ from torch_scatter import scatter_logsumexp ...@@ -3,19 +3,18 @@ from torch_scatter import scatter_logsumexp
def test_logsumexp(): def test_logsumexp():
src = torch.tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, -100]) inputs = torch.tensor([
src.requires_grad_() 0.5, 0.5, 0.0, -2.1, 3.2, 7.0, -1.0, -100.0,
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4]) float('-inf'),
float('-inf'), 0.0
])
inputs.requires_grad_()
index = torch.tensor([0, 0, 1, 1, 1, 2, 4, 4, 5, 6, 6])
splits = [2, 3, 1, 0, 2, 1, 2]
out = scatter_logsumexp(src, index) outputs = scatter_logsumexp(inputs, index)
out0 = torch.logsumexp(torch.tensor([0.5, 0.5]), dim=-1) for src, out in zip(inputs.split(splits), outputs.unbind()):
out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2]), dim=-1) assert out.tolist() == torch.logsumexp(src, dim=0).tolist()
out2 = torch.logsumexp(torch.tensor(7, dtype=torch.float), dim=-1)
out3 = torch.logsumexp(torch.tensor([], dtype=torch.float), dim=-1)
out4 = torch.tensor(-1, dtype=torch.float)
expected = torch.stack([out0, out1, out2, out3, out4], dim=0) outputs.backward(torch.randn_like(outputs))
assert torch.allclose(out, expected)
out.backward(torch.randn_like(out))
...@@ -23,16 +23,19 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -23,16 +23,19 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
if dim_size is None: if dim_size is None:
dim_size = int(index.max()) + 1 dim_size = int(index.max()) + 1
size = src.size() size = list(src.size())
size[dim] = dim_size size[dim] = dim_size
max_value_per_index = scatter_max(src, index, dim=dim, dim_size=dim_size)[0] max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
device=src.device)
scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)[0]
max_per_src_element = max_value_per_index.gather(dim, index) max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element recentered_score = src - max_per_src_element
recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf'))
if out is not None: if out is not None:
out = out.sub_(max_per_src_element).exp_() out = out.sub(max_per_src_element).exp()
sum_per_index = scatter_sum(recentered_scores.exp_(), index, dim, out, sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out,
dim_size) dim_size)
return sum_per_index.add_(eps).log_().add_(max_value_per_index) return sum_per_index.add_(eps).log_().add_(max_value_per_index)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment