"test/vscode:/vscode.git/clone" did not exist on "877e35d7754cd1fa60b3f1226929dbc84146ea70"
Commit 670e9f99 authored by rusty1s's avatar rusty1s
Browse files

pass out tensor to internal scatters (if present)

parent 0961269f
...@@ -22,8 +22,9 @@ def test_std(dtype, device, bias): ...@@ -22,8 +22,9 @@ def test_std(dtype, device, bias):
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_empty_std(dtype, device): def test_empty_std(dtype, device):
src = tensor([], dtype, device) out = torch.zeros(1, 5, dtype=dtype, device=device)
index = tensor([], torch.long, device) src = tensor([], dtype, device).view(0, 5)
index = tensor([], torch.long, device).view(0, 5)
out = scatter_std(src, index, dim=-1) out = scatter_std(src, index, dim=0, out=out)
assert out.tolist() == [] assert out.tolist() == [[0, 0, 0, 0, 0]]
...@@ -48,8 +48,12 @@ def scatter_std(src, index, dim=-1, out=None, dim_size=None, unbiased=True): ...@@ -48,8 +48,12 @@ def scatter_std(src, index, dim=-1, out=None, dim_size=None, unbiased=True):
""" """
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value=0) src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value=0)
tmp = scatter_add(src, index, dim, None, dim_size) tmp = None if out is None else out.clone().fill_(0)
count = scatter_add(torch.ones_like(src), index, dim, None, dim_size) tmp = scatter_add(src, index, dim, tmp, dim_size)
count = None if out is None else out.clone().fill_(0)
count = scatter_add(torch.ones_like(src), index, dim, count, dim_size)
mean = tmp / count.clamp(min=1) mean = tmp / count.clamp(min=1)
var = (src - mean.gather(dim, index)) var = (src - mean.gather(dim, index))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment