Commit d7353409 authored by rusty1s's avatar rusty1s
Browse files

cleaner

parent d3857770
......@@ -56,7 +56,7 @@ def scatter_mean_(output, index, input, dim=0):
return output
def scatter_mean(index, input, dim=0, max_index=None, fill_value=1):
def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
output = gen_output(index, input, dim, max_index, fill_value)
return scatter_mean_(output, index, input, dim)
......
......@@ -13,30 +13,29 @@ def _scatter(name, dim, *data):
class _Scatter(Function):
def __init__(self, name, dim):
super(_Scatter, self).__init__()
self.dim = dim
self.name = name
self.dim = dim
def forward(self, *data):
assert not self.needs_input_grad[1], 'Can\'t differentiate the index'
self.mark_dirty(data[0])
self.save_for_backward(data[1])
self.mark_dirty(data[0]) # Mark output as dirty.
self.len = len(data) # Save number of arguments for backward step
self.save_for_backward(data[1]) # Save index for backward step.
_scatter(self.name, self.dim, *data)
return data[0]
def backward(self, grad):
def backward(self, *data):
index, = self.saved_variables
grad_output = grad_input = None
if self.needs_input_grad[0]:
grad_output = grad
grad_output = data[0]
if self.needs_input_grad[2]:
grad_input = grad.gather(self.dim, index.data)
grad_input = data[0].gather(self.dim, index.data)
if len(grad) == 3:
return grad_output, None, grad_input
return grad_output, None, grad_input, None
return (grad_output, None, grad_input) + (None, ) * (self.len - 3)
def scatter(name, dim, *data):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment