std.py 683 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
import torch

from torch_scatter import scatter_add
from torch_scatter.utils.gen import gen


def scatter_std(src, index, dim=-1, out=None, dim_size=None, unbiased=True):
    """Compute the per-group standard deviation of ``src`` along ``dim``.

    Entries of ``src`` are grouped by the corresponding entries of ``index``;
    for every group the sum, element count and mean are computed, and the
    squared deviations from the group mean are accumulated and normalised.

    Args:
        src: Source tensor holding the values to reduce.
        index: Tensor of group indices, used both for ``scatter_add`` and for
            gathering each element's group mean back along ``dim``.
        dim: Axis along which to group. Defaults to ``-1``.
        out: Optional destination tensor. It is passed to ``gen`` and then to
            the final ``scatter_add`` as the accumulation target.
            NOTE(review): a caller-supplied ``out`` with non-zero contents
            presumably contributes to the summed squared deviations -- confirm
            against ``scatter_add``.
        dim_size: Optional size of the output along ``dim``; forwarded
            unchanged to ``gen`` and ``scatter_add``.
        unbiased: If ``True`` divide by ``count - 1``, otherwise by ``count``.
            In both cases the divisor is clamped to a minimum of 1.

    Returns:
        Tensor of per-group standard deviations. Because the variance is
        clamped to ``1e-12`` before the square root, no entry is smaller than
        ``1e-6``.
    """
    # `gen` normalises the arguments (helper defined in torch_scatter.utils);
    # presumably it broadcasts `index`, resolves a negative `dim` and
    # allocates `out` filled with `fill_value` when none is given -- TODO
    # confirm against its implementation.
    src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value=0)

    # Per-group sum and element count. `None` is passed as the destination so
    # these intermediates do not accumulate into the caller's `out`.
    tmp = scatter_add(src, index, dim, None, dim_size)
    count = scatter_add(torch.ones_like(src), index, dim, None, dim_size)
    # Clamp so groups without any element do not divide by zero.
    mean = tmp / count.clamp(min=1)

    # Squared deviation of every element from the mean of its own group.
    var = (src - mean.gather(dim, index))
    var = var * var

    out = scatter_add(var, index, dim, out, dim_size)
    # Clamp the divisor so empty groups (and single-element groups in the
    # unbiased case) do not divide by zero.
    denom = count - 1 if unbiased else count
    out = out / denom.clamp(min=1)
    # Clamp before the square root so the result is never exactly zero --
    # presumably to keep the gradient of sqrt finite; verify intent.
    out = torch.sqrt(out.clamp(min=1e-12))

    return out