Commit d7353409 authored by rusty1s's avatar rusty1s
Browse files

cleaner

parent d3857770
...@@ -56,7 +56,7 @@ def scatter_mean_(output, index, input, dim=0): ...@@ -56,7 +56,7 @@ def scatter_mean_(output, index, input, dim=0):
return output return output
def scatter_mean(index, input, dim=0, max_index=None, fill_value=1): def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
output = gen_output(index, input, dim, max_index, fill_value) output = gen_output(index, input, dim, max_index, fill_value)
return scatter_mean_(output, index, input, dim) return scatter_mean_(output, index, input, dim)
......
...@@ -13,30 +13,29 @@ def _scatter(name, dim, *data): ...@@ -13,30 +13,29 @@ def _scatter(name, dim, *data):
class _Scatter(Function): class _Scatter(Function):
def __init__(self, name, dim): def __init__(self, name, dim):
super(_Scatter, self).__init__() super(_Scatter, self).__init__()
self.dim = dim
self.name = name self.name = name
self.dim = dim
def forward(self, *data): def forward(self, *data):
assert not self.needs_input_grad[1], 'Can\'t differentiate the index' assert not self.needs_input_grad[1], 'Can\'t differentiate the index'
self.mark_dirty(data[0]) self.mark_dirty(data[0]) # Mark output as dirty.
self.save_for_backward(data[1]) self.len = len(data) # Save number of arguments for backward step
self.save_for_backward(data[1]) # Save index for backward step.
_scatter(self.name, self.dim, *data) _scatter(self.name, self.dim, *data)
return data[0] return data[0]
def backward(self, grad): def backward(self, *data):
index, = self.saved_variables index, = self.saved_variables
grad_output = grad_input = None grad_output = grad_input = None
if self.needs_input_grad[0]: if self.needs_input_grad[0]:
grad_output = grad grad_output = data[0]
if self.needs_input_grad[2]: if self.needs_input_grad[2]:
grad_input = grad.gather(self.dim, index.data) grad_input = data[0].gather(self.dim, index.data)
if len(grad) == 3: return (grad_output, None, grad_input) + (None, ) * (self.len - 3)
return grad_output, None, grad_input
return grad_output, None, grad_input, None
def scatter(name, dim, *data): def scatter(name, dim, *data):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment