Commit 670e9f99 authored by rusty1s's avatar rusty1s
Browse files

pass out tensor to internal scatters (if present)

parent 0961269f
......@@ -22,8 +22,9 @@ def test_std(dtype, device, bias):
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_empty_std(dtype, device):
src = tensor([], dtype, device)
index = tensor([], torch.long, device)
out = torch.zeros(1, 5, dtype=dtype, device=device)
src = tensor([], dtype, device).view(0, 5)
index = tensor([], torch.long, device).view(0, 5)
out = scatter_std(src, index, dim=-1)
assert out.tolist() == []
out = scatter_std(src, index, dim=0, out=out)
assert out.tolist() == [[0, 0, 0, 0, 0]]
......@@ -48,8 +48,12 @@ def scatter_std(src, index, dim=-1, out=None, dim_size=None, unbiased=True):
"""
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value=0)
tmp = scatter_add(src, index, dim, None, dim_size)
count = scatter_add(torch.ones_like(src), index, dim, None, dim_size)
tmp = None if out is None else out.clone().fill_(0)
tmp = scatter_add(src, index, dim, tmp, dim_size)
count = None if out is None else out.clone().fill_(0)
count = scatter_add(torch.ones_like(src), index, dim, count, dim_size)
mean = tmp / count.clamp(min=1)
var = (src - mean.gather(dim, index))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment