Commit fc012e08 authored by rusty1s's avatar rusty1s
Browse files

bugfix for scatter mean

parent 93e779f8
...@@ -60,7 +60,7 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -60,7 +60,7 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
index_dim = dim index_dim = dim
if index_dim < 0: if index_dim < 0:
index_dim = index_dim + src.dim() index_dim = index_dim + src.dim()
if index.dim() <= dim: if index.dim() <= index_dim:
index_dim = index.dim() - 1 index_dim = index.dim() - 1
ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment