"tests/vscode:/vscode.git/clone" did not exist on "fe6d09ae61f2281417e35f53a948b6fa898a4eba"
Commit 8c48dee0 authored by rusty1s's avatar rusty1s
Browse files

comments

parent 5ba5c620
......@@ -56,10 +56,13 @@ class _Scatter(Function):
assert not self.needs_input_grad[1], 'Can\'t differentiate the index'
self.mark_dirty(data[0]) # Mark output as dirty.
self.len = len(data) # Save number of arguments for backward step
self.len = len(data) # Save number of arguments for backward step.
_scatter(self.name, self.dim, *data)
# `scatter_min` and `scatter_max` additionally return the `argmax`
# respectively `argmin`. In addition, we need to save the
# `output_index` for the backward pass.
if _has_output_index(self.name):
self.save_for_backward(data[1], data[3])
return data[0], data[3]
......@@ -73,6 +76,8 @@ class _Scatter(Function):
if self.needs_input_grad[0]:
grad_output = data[0]
# Different grad computation of `input` if `scatter_max` or
# `scatter_min` was used.
if self.needs_input_grad[2] and not _has_output_index(self.name):
index, = self.saved_variables
grad_input = data[0].gather(self.dim, index.data)
......@@ -82,6 +87,8 @@ class _Scatter(Function):
data = (index.data, data[0], grad_index.data)
grad_input = _index_backward(self.dim, *data)
# Return and fill with empty grads for none-differentiable passed
# arguments in forward pass.
return (grad_output, None, grad_input) + (None, ) * (self.len - 3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment