Commit 102b7542 authored by rusty1s

typos

parent 46c2a5cb
@@ -50,8 +50,8 @@ def scatter_max_(output, index, input, dim=0):
     [torch.LongTensor of size 2x6]
     )
     """
-    arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
-    return scatter('max', dim, output, index, input, arg_output)
+    arg = gen_filled_tensor(index, output.size(), fill_value=-1)
+    return scatter('max', dim, output, index, input, arg)
 
 
 def scatter_max(index, input, dim=0, size=None, fill_value=0):
...
@@ -44,10 +44,10 @@ def scatter_mean_(output, index, input, dim=0):
     1.0000 4.0000 2.0000 0.0000 0.0000 0.0000
     [torch.FloatTensor of size 2x6]
     """
-    num_output = gen_filled_tensor(output, output.size(), fill_value=0)
-    scatter('mean', dim, output, index, input, num_output)
-    num_output[num_output == 0] = 1
-    output /= num_output
+    count = gen_filled_tensor(output, output.size(), fill_value=0)
+    scatter('mean', dim, output, index, input, count)
+    count[count == 0] = 1
+    output /= count
     return output
...
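A side note on the renamed `count` tensor above: it records how many input values were scattered into each output slot, so a scatter-sum can be turned into a mean, with empty slots clamped to 1 to avoid division by zero. A minimal pure-PyTorch sketch of the same idea (illustrative only; the example tensors below are assumptions and not taken from this repository):

import torch

index = torch.tensor([[4, 5, 4, 2, 3],
                      [0, 0, 2, 2, 1]])
input = torch.tensor([[2.0, 0.0, 1.0, 4.0, 3.0],
                      [0.0, 2.0, 1.0, 3.0, 4.0]])

output = torch.zeros(2, 6)
count = torch.zeros(2, 6)

output.scatter_add_(1, index, input)                  # per-slot sum
count.scatter_add_(1, index, torch.ones_like(input))  # per-slot element count
count[count == 0] = 1                                 # clamp empty slots to avoid dividing by zero
output /= count                                       # sum -> mean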
@@ -50,8 +50,8 @@ def scatter_min_(output, index, input, dim=0):
     [torch.LongTensor of size 2x6]
     )
     """
-    arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
-    return scatter('min', dim, output, index, input, arg_output)
+    arg = gen_filled_tensor(index, output.size(), fill_value=-1)
+    return scatter('min', dim, output, index, input, arg)
 
 
 def scatter_min(index, input, dim=0, size=None, fill_value=0):
...
@@ -6,7 +6,7 @@ from torch.autograd import Function
 from .._ext import ffi
 
 
-def has_arg_output(name):
+def has_arg(name):
     return name in ['max', 'min']
@@ -36,7 +36,7 @@ def _scatter(name, dim, *data):
     cuda = 'cuda_' if data[0].is_cuda else ''
     func = getattr(ffi, 'scatter_{}_{}{}'.format(name, cuda, typename))
     func(dim, *data)
-    return (data[0], data[3]) if has_arg_output(name) else data[0]
+    return (data[0], data[3]) if has_arg(name) else data[0]
 
 
 def index_backward(dim, index, grad, arg):
@@ -63,9 +63,9 @@ class _Scatter(Function):
         _scatter(self.name, self.dim, *data)
 
         # `scatter_min` and `scatter_max` additionally return the `argmax`
-        # respectively `argmin`. In addition, we need to save the
-        # `arg_output` for the backward pass.
-        if has_arg_output(self.name):
+        # respectively `argmin`. Therefore, we need to save the `arg` for the
+        # backward pass.
+        if has_arg(self.name):
             self.save_for_backward(data[1], data[3])
             return data[0], data[3]
         else:
@@ -80,13 +80,13 @@ class _Scatter(Function):
         # Different grad computation of `input` if `scatter_max` or
         # `scatter_min` was used.
-        if self.needs_input_grad[2] and not has_arg_output(self.name):
+        if self.needs_input_grad[2] and not has_arg(self.name):
             index, = self.saved_variables
             grad_input = data[0].gather(self.dim, index.data)
 
-        if self.needs_input_grad[2] and has_arg_output(self.name):
-            index, arg_grad = self.saved_variables
-            data = (index.data, data[0], arg_grad.data)
+        if self.needs_input_grad[2] and has_arg(self.name):
+            index, arg = self.saved_variables
+            data = (index.data, data[0], arg.data)
             grad_input = index_backward(self.dim, *data)
 
         # Return and fill with empty grads for none-differentiable passed
...
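For context on why `arg` is saved for the backward pass: with `scatter_max`/`scatter_min`, the gradient of each output slot should flow back only to the one input element that actually produced that slot's value, and the saved arg(max|min) indices record which element that was. A rough pure-PyTorch sketch of what `index_backward` is assumed to compute (an illustration under that assumption, not the repository's C extension):

import torch

def index_backward_sketch(dim, index, grad, arg):
    # Position of every input element along `dim`, broadcast to `index`'s shape.
    shape = [-1 if d == dim else 1 for d in range(index.dim())]
    pos = torch.arange(index.size(dim)).view(shape).expand_as(index)
    # An input element receives its target slot's gradient only if it was the
    # arg(max|min) of that slot, i.e. if arg at the target slot equals its own position.
    mask = arg.gather(dim, index) == pos
    return grad.gather(dim, index) * mask.type_as(grad)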