"research/autoencoder/autoencoder_models/Autoencoder.py" did not exist on "a8ba923c873f9848d0f6453f3e2e3fa2dd1187dc"
Commit 2a3dca8e authored by rusty1s's avatar rusty1s
Browse files

numerical stability

parent 3e409bf4
...@@ -6,14 +6,16 @@ from torch_scatter.utils.gen import gen ...@@ -6,14 +6,16 @@ from torch_scatter.utils.gen import gen
def scatter_std(src, index, dim=-1, out=None, dim_size=None, unbiased=True): def scatter_std(src, index, dim=-1, out=None, dim_size=None, unbiased=True):
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value=0) src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value=0)
print('src', src.mean())
tmp = scatter_add(src, index, dim, None, dim_size) tmp = scatter_add(src, index, dim, None, dim_size)
count = scatter_add(torch.ones_like(src), index, dim, None, dim_size) count = scatter_add(torch.ones_like(src), index, dim, None, dim_size)
mean = tmp / count.clamp(min=1) mean = tmp / count.clamp(min=1)
var = (src - mean.gather(dim, index))**2 var = (src - mean.gather(dim, index))
var = var * var
out = scatter_add(var, index, dim, out, dim_size) out = scatter_add(var, index, dim, out, dim_size)
out = out / (count - 1 if unbiased else count).clamp(min=1) out = out / (count - 1 if unbiased else count).clamp(min=1)
out = torch.sqrt(out) out = torch.sqrt(out.clamp(min=1e-12))
return out return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment