Commit 3e409bf4 authored by rusty1s's avatar rusty1s
Browse files

scatter_std

parent d305ecc0
...@@ -3,12 +3,20 @@ from .sub import scatter_sub ...@@ -3,12 +3,20 @@ from .sub import scatter_sub
from .mul import scatter_mul from .mul import scatter_mul
from .div import scatter_div from .div import scatter_div
from .mean import scatter_mean from .mean import scatter_mean
from .std import scatter_std
from .max import scatter_max from .max import scatter_max
from .min import scatter_min from .min import scatter_min
__version__ = '1.0.4' __version__ = '1.0.4'
__all__ = [ __all__ = [
'scatter_add', 'scatter_sub', 'scatter_mul', 'scatter_div', 'scatter_mean', 'scatter_add',
'scatter_max', 'scatter_min', '__version__' 'scatter_sub',
'scatter_mul',
'scatter_div',
'scatter_mean',
'scatter_std',
'scatter_max',
'scatter_min',
'__version__',
] ]
import torch
from torch_scatter import scatter_add
from torch_scatter.utils.gen import gen
def scatter_std(src, index, dim=-1, out=None, dim_size=None, unbiased=True):
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value=0)
tmp = scatter_add(src, index, dim, None, dim_size)
count = scatter_add(torch.ones_like(src), index, dim, None, dim_size)
mean = tmp / count.clamp(min=1)
var = (src - mean.gather(dim, index))**2
out = scatter_add(var, index, dim, out, dim_size)
out = out / (count - 1 if unbiased else count).clamp(min=1)
out = torch.sqrt(out)
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment