"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "ad0094fa08f4f40661dc860c80438c8630b43dfd"
Commit aa80bb88 authored by rusty1s's avatar rusty1s
Browse files

mean with two scatters

parent 75514469
from torch.autograd import Function import torch
from .utils.ffi import get_func from .add import scatter_add
from .utils.gen import gen
class ScatterMean(Function):
@staticmethod
def forward(ctx, out, src, index, dim):
count = src.new_zeros(out.size())
func = get_func('scatter_mean', src)
func(dim, out, index, src, count)
count[count == 0] = 1
out /= count
ctx.mark_dirty(out)
ctx.save_for_backward(index, count)
ctx.dim = dim
return out
@staticmethod
def backward(ctx, grad_out):
index, count = ctx.saved_variables
grad_src = None
if ctx.needs_input_grad[1]:
grad_src = grad_out.gather(ctx.dim, index)
grad_src /= count.gather(ctx.dim, index)
return None, grad_src, None, None
def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0): def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
...@@ -93,5 +65,6 @@ def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0): ...@@ -93,5 +65,6 @@ def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
tensor([[ 0.0000, 0.0000, 4.0000, 3.0000, 1.5000, 0.0000], tensor([[ 0.0000, 0.0000, 4.0000, 3.0000, 1.5000, 0.0000],
[ 1.0000, 4.0000, 2.0000, 0.0000, 0.0000, 0.0000]]) [ 1.0000, 4.0000, 2.0000, 0.0000, 0.0000, 0.0000]])
""" """
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value) out = scatter_add(src, index, dim, out, dim_size, fill_value)
return ScatterMean.apply(out, src, index, dim) count = scatter_add(torch.ones_like(src), index, dim, None, out.size(dim))
return out / count.clamp(min=1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment