Commit 8c48dee0 authored by rusty1s's avatar rusty1s
Browse files

comments

parent 5ba5c620
...@@ -56,10 +56,13 @@ class _Scatter(Function): ...@@ -56,10 +56,13 @@ class _Scatter(Function):
assert not self.needs_input_grad[1], 'Can\'t differentiate the index' assert not self.needs_input_grad[1], 'Can\'t differentiate the index'
self.mark_dirty(data[0]) # Mark output as dirty. self.mark_dirty(data[0]) # Mark output as dirty.
self.len = len(data) # Save number of arguments for backward step self.len = len(data) # Save number of arguments for backward step.
_scatter(self.name, self.dim, *data) _scatter(self.name, self.dim, *data)
# `scatter_min` and `scatter_max` additionally return the `argmax`
# respectively `argmin`. In addition, we need to save the
# `output_index` for the backward pass.
if _has_output_index(self.name): if _has_output_index(self.name):
self.save_for_backward(data[1], data[3]) self.save_for_backward(data[1], data[3])
return data[0], data[3] return data[0], data[3]
...@@ -73,6 +76,8 @@ class _Scatter(Function): ...@@ -73,6 +76,8 @@ class _Scatter(Function):
if self.needs_input_grad[0]: if self.needs_input_grad[0]:
grad_output = data[0] grad_output = data[0]
# Different grad computation of `input` if `scatter_max` or
# `scatter_min` was used.
if self.needs_input_grad[2] and not _has_output_index(self.name): if self.needs_input_grad[2] and not _has_output_index(self.name):
index, = self.saved_variables index, = self.saved_variables
grad_input = data[0].gather(self.dim, index.data) grad_input = data[0].gather(self.dim, index.data)
...@@ -82,6 +87,8 @@ class _Scatter(Function): ...@@ -82,6 +87,8 @@ class _Scatter(Function):
data = (index.data, data[0], grad_index.data) data = (index.data, data[0], grad_index.data)
grad_input = _index_backward(self.dim, *data) grad_input = _index_backward(self.dim, *data)
# Return and fill with empty grads for none-differentiable passed
# arguments in forward pass.
return (grad_output, None, grad_input) + (None, ) * (self.len - 3) return (grad_output, None, grad_input) + (None, ) * (self.len - 3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment