".github/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "d69928af8be66a944e057cf21638d1a58e821fb8"
Commit cf8cf0c0 authored by rusty1s

fix backward pass

parent 5b1737ab
@@ -4,6 +4,7 @@ from torch_scatter import scatter_logsumexp
 def test_logsumexp():
     src = torch.tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, -100])
+    src.requires_grad_()
     index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
     out = scatter_logsumexp(src, index)
@@ -16,3 +17,5 @@ def test_logsumexp():
     expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
     assert torch.allclose(out, expected)
+
+    out.backward(torch.randn_like(out))
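The logsumexp test now drives a real backward pass via out.backward(torch.randn_like(out)). As a complementary check (not part of this commit), torch.autograd.gradcheck can verify the same gradient numerically; a minimal sketch, assuming double-precision inputs are supported by scatter_logsumexp:

import torch
from torch_scatter import scatter_logsumexp

# Sketch only (not from this commit): numerically check the gradient of
# scatter_logsumexp; gradcheck expects double-precision inputs.
src = torch.rand(6, dtype=torch.double, requires_grad=True)
index = torch.tensor([0, 1, 0, 1, 2, 2])
torch.autograd.gradcheck(lambda s: scatter_logsumexp(s, index), (src,))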
@@ -4,6 +4,7 @@ from torch_scatter import scatter_log_softmax, scatter_softmax
 def test_softmax():
     src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
+    src.requires_grad_()
     index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
     out = scatter_softmax(src, index)
@@ -19,9 +20,12 @@ def test_softmax():
     assert torch.allclose(out, expected)
+
+    out.backward(torch.randn_like(out))


 def test_log_softmax():
     src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
+    src.requires_grad_()
     index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
     out = scatter_log_softmax(src, index)
@@ -36,3 +40,5 @@ def test_log_softmax():
     ], dim=0)
     assert torch.allclose(out, expected)
+
+    out.backward(torch.randn_like(out))
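Both softmax tests follow the same pattern: enable gradients on src, then backpropagate a random gradient through the output. A hedged usage sketch of what this exercises, using finite inputs (assumed example, not taken from the commit):

import torch
from torch_scatter import scatter_softmax

# Assumed example (not from this commit): gradients now flow through
# scatter_softmax back to the source tensor.
src = torch.tensor([0.2, 0.0, 0.2, -2.1, 3.2], requires_grad=True)
index = torch.tensor([0, 1, 0, 1, 1])
out = scatter_softmax(src, index)
out.backward(torch.randn_like(out))
print(src.grad)  # finite gradients, one entry per source element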
@@ -4,9 +4,12 @@ from torch_scatter import scatter_std
 def test_std():
     src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], dtype=torch.float)
+    src.requires_grad_()
     index = torch.tensor([[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.long)
     out = scatter_std(src, index, dim=-1, unbiased=True)
     std = src.std(dim=-1, unbiased=True)[0]
     expected = torch.tensor([[std, 0], [0, std]])
     assert torch.allclose(out, expected)
+
+    out.backward(torch.randn_like(out))
@@ -17,12 +17,12 @@ def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
     max_per_src_element = max_value_per_index.gather(dim, index)
     recentered_scores = src - max_per_src_element
-    recentered_scores_exp = recentered_scores.exp_()
+    recentered_scores_exp = recentered_scores.exp()
     sum_per_index = scatter_sum(recentered_scores_exp, index, dim)
     normalizing_constants = sum_per_index.add_(eps).gather(dim, index)
-    return recentered_scores_exp.div_(normalizing_constants)
+    return recentered_scores_exp.div(normalizing_constants)


 @torch.jit.script
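This hunk is the substance of the fix for scatter_softmax: autograd saves the output of exp to compute its gradient, and the subsequent in-place division overwrites that saved tensor, so backward raises a RuntimeError. The commit switches both operations to their out-of-place variants. A minimal reproduction of the failure mode, independent of torch_scatter (assumed illustration):

import torch

# Assumed illustration (not from this commit): an in-place op on a tensor that
# autograd saved for backward makes the backward pass raise a RuntimeError.
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x.exp()        # autograd saves y, because d(exp(x))/dx = exp(x) = y
y.div_(y.sum())    # in-place division overwrites the saved tensor
try:
    y.sum().backward()
except RuntimeError as err:
    print(err)     # "... has been modified by an inplace operation"

The out-of-place exp() and div() keep the saved tensors intact, at the cost of one extra allocation each.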
@@ -27,14 +27,14 @@ def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
     index = broadcast(index, src, dim)
     tmp = scatter_sum(src, index, dim, dim_size=dim_size)
     count = broadcast(count, tmp, dim).clamp_(1)
-    mean = tmp.div_(count)
+    mean = tmp.div(count)
     var = (src - mean.gather(dim, index))
     var = var * var
     out = scatter_sum(var, index, dim, out, dim_size)

     if unbiased:
-        count.sub_(1).clamp_(1)
-    out.div_(count).sqrt_()
+        count = count.sub(1).clamp_(1)
+    out = out.div(count).sqrt()

     return out
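The scatter_std changes follow the same idea, replacing the in-place div_/sub_/sqrt_ calls and rebinding count and out instead of mutating them. A short usage sketch of the behaviour this enables (assumed example, not from the commit):

import torch
from torch_scatter import scatter_std

# Assumed example (not from this commit): backward through scatter_std now
# completes and populates src.grad.
src = torch.tensor([[2.0, 0.0, 1.0, 4.0, 3.0]], requires_grad=True)
index = torch.tensor([[0, 0, 1, 1, 1]])
out = scatter_std(src, index, dim=-1, unbiased=True)
out.backward(torch.ones_like(out))
print(src.grad)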