Commit cf8cf0c0 authored by rusty1s's avatar rusty1s
Browse files

fix backward pass

parent 5b1737ab
......@@ -4,6 +4,7 @@ from torch_scatter import scatter_logsumexp
def test_logsumexp():
src = torch.tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, -100])
src.requires_grad_()
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
out = scatter_logsumexp(src, index)
......@@ -16,3 +17,5 @@ def test_logsumexp():
expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
assert torch.allclose(out, expected)
out.backward(torch.randn_like(out))
......@@ -4,6 +4,7 @@ from torch_scatter import scatter_log_softmax, scatter_softmax
def test_softmax():
src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
src.requires_grad_()
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
out = scatter_softmax(src, index)
......@@ -19,9 +20,12 @@ def test_softmax():
assert torch.allclose(out, expected)
out.backward(torch.randn_like(out))
def test_log_softmax():
src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
src.requires_grad_()
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
out = scatter_log_softmax(src, index)
......@@ -36,3 +40,5 @@ def test_log_softmax():
], dim=0)
assert torch.allclose(out, expected)
out.backward(torch.randn_like(out))
......@@ -4,9 +4,12 @@ from torch_scatter import scatter_std
def test_std():
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], dtype=torch.float)
src.requires_grad_()
index = torch.tensor([[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.long)
out = scatter_std(src, index, dim=-1, unbiased=True)
std = src.std(dim=-1, unbiased=True)[0]
expected = torch.tensor([[std, 0], [0, std]])
assert torch.allclose(out, expected)
out.backward(torch.randn_like(out))
......@@ -17,12 +17,12 @@ def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
recentered_scores_exp = recentered_scores.exp_()
recentered_scores_exp = recentered_scores.exp()
sum_per_index = scatter_sum(recentered_scores_exp, index, dim)
normalizing_constants = sum_per_index.add_(eps).gather(dim, index)
return recentered_scores_exp.div_(normalizing_constants)
return recentered_scores_exp.div(normalizing_constants)
@torch.jit.script
......
......@@ -27,14 +27,14 @@ def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
index = broadcast(index, src, dim)
tmp = scatter_sum(src, index, dim, dim_size=dim_size)
count = broadcast(count, tmp, dim).clamp_(1)
mean = tmp.div_(count)
mean = tmp.div(count)
var = (src - mean.gather(dim, index))
var = var * var
out = scatter_sum(var, index, dim, out, dim_size)
if unbiased:
count.sub_(1).clamp_(1)
out.div_(count).sqrt_()
count = count.sub(1).clamp_(1)
out = out.div(count).sqrt()
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment