Commit aa80bb88 authored by rusty1s's avatar rusty1s
Browse files

mean with two scatters

parent 75514469
from torch.autograd import Function
import torch
from .utils.ffi import get_func
from .utils.gen import gen
class ScatterMean(Function):
@staticmethod
def forward(ctx, out, src, index, dim):
count = src.new_zeros(out.size())
func = get_func('scatter_mean', src)
func(dim, out, index, src, count)
count[count == 0] = 1
out /= count
ctx.mark_dirty(out)
ctx.save_for_backward(index, count)
ctx.dim = dim
return out
@staticmethod
def backward(ctx, grad_out):
index, count = ctx.saved_variables
grad_src = None
if ctx.needs_input_grad[1]:
grad_src = grad_out.gather(ctx.dim, index)
grad_src /= count.gather(ctx.dim, index)
return None, grad_src, None, None
from .add import scatter_add
def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
......@@ -93,5 +65,6 @@ def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
tensor([[ 0.0000, 0.0000, 4.0000, 3.0000, 1.5000, 0.0000],
[ 1.0000, 4.0000, 2.0000, 0.0000, 0.0000, 0.0000]])
"""
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
return ScatterMean.apply(out, src, index, dim)
out = scatter_add(src, index, dim, out, dim_size, fill_value)
count = scatter_add(torch.ones_like(src), index, dim, None, out.size(dim))
return out / count.clamp(min=1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment