Commit 2a3dca8e authored by rusty1s's avatar rusty1s
Browse files

numerical stability

parent 3e409bf4
......@@ -6,14 +6,16 @@ from torch_scatter.utils.gen import gen
def scatter_std(src, index, dim=-1, out=None, dim_size=None, unbiased=True):
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value=0)
print('src', src.mean())
tmp = scatter_add(src, index, dim, None, dim_size)
count = scatter_add(torch.ones_like(src), index, dim, None, dim_size)
mean = tmp / count.clamp(min=1)
var = (src - mean.gather(dim, index))**2
var = (src - mean.gather(dim, index))
var = var * var
out = scatter_add(var, index, dim, out, dim_size)
out = out / (count - 1 if unbiased else count).clamp(min=1)
out = torch.sqrt(out)
out = torch.sqrt(out.clamp(min=1e-12))
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment