Commit ae31f726 authored by rusty1s's avatar rusty1s
Browse files

fix scatter_std bug

parent 3a58d6f6
...@@ -26,7 +26,7 @@ def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -26,7 +26,7 @@ def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
index = broadcast(index, src, dim) index = broadcast(index, src, dim)
tmp = scatter_sum(src, index, dim, dim_size=dim_size) tmp = scatter_sum(src, index, dim, dim_size=dim_size)
count = broadcast(count, tmp, dim).clamp_(1) count = broadcast(count, tmp, dim).clamp(1)
mean = tmp.div(count) mean = tmp.div(count)
var = (src - mean.gather(dim, index)) var = (src - mean.gather(dim, index))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment