Commit f655f536 authored by rusty1s's avatar rusty1s
Browse files

typo

parent c541e366
...@@ -6,7 +6,7 @@ from torch.autograd import Function ...@@ -6,7 +6,7 @@ from torch.autograd import Function
from .._ext import ffi from .._ext import ffi
def _has_output_index(name): def has_output_index(name):
return name in ['max', 'min'] return name in ['max', 'min']
...@@ -35,10 +35,10 @@ def _scatter(name, dim, *data): ...@@ -35,10 +35,10 @@ def _scatter(name, dim, *data):
typename = type(data[0]).__name__.replace('Tensor', '') typename = type(data[0]).__name__.replace('Tensor', '')
func = getattr(ffi, 'scatter_{}_{}'.format(name, typename)) func = getattr(ffi, 'scatter_{}_{}'.format(name, typename))
func(dim, *data) func(dim, *data)
return (data[0], data[3]) if _has_output_index(name) else data[0] return (data[0], data[3]) if has_output_index(name) else data[0]
def _index_backward(dim, index, grad, grad_index): def index_backward(dim, index, grad, grad_index):
typename = type(grad).__name__.replace('Tensor', '') typename = type(grad).__name__.replace('Tensor', '')
func = getattr(ffi, 'index_backward_{}'.format(typename)) func = getattr(ffi, 'index_backward_{}'.format(typename))
output = grad.new(index.size()).fill_(0) output = grad.new(index.size()).fill_(0)
...@@ -63,7 +63,7 @@ class _Scatter(Function): ...@@ -63,7 +63,7 @@ class _Scatter(Function):
# `scatter_min` and `scatter_max` additionally return the `argmax` # `scatter_min` and `scatter_max` additionally return the `argmax`
# respectively `argmin`. In addition, we need to save the # respectively `argmin`. In addition, we need to save the
# `output_index` for the backward pass. # `output_index` for the backward pass.
if _has_output_index(self.name): if has_output_index(self.name):
self.save_for_backward(data[1], data[3]) self.save_for_backward(data[1], data[3])
return data[0], data[3] return data[0], data[3]
else: else:
...@@ -78,14 +78,14 @@ class _Scatter(Function): ...@@ -78,14 +78,14 @@ class _Scatter(Function):
# Different grad computation of `input` if `scatter_max` or # Different grad computation of `input` if `scatter_max` or
# `scatter_min` was used. # `scatter_min` was used.
if self.needs_input_grad[2] and not _has_output_index(self.name): if self.needs_input_grad[2] and not has_output_index(self.name):
index, = self.saved_variables index, = self.saved_variables
grad_input = data[0].gather(self.dim, index.data) grad_input = data[0].gather(self.dim, index.data)
if self.needs_input_grad[2] and _has_output_index(self.name): if self.needs_input_grad[2] and has_output_index(self.name):
index, grad_index = self.saved_variables index, grad_index = self.saved_variables
data = (index.data, data[0], grad_index.data) data = (index.data, data[0], grad_index.data)
grad_input = _index_backward(self.dim, *data) grad_input = index_backward(self.dim, *data)
# Return and fill with empty grads for none-differentiable passed # Return and fill with empty grads for none-differentiable passed
# arguments in forward pass. # arguments in forward pass.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment