Commit 102b7542 authored by rusty1s

typos

parent 46c2a5cb
@@ -50,8 +50,8 @@ def scatter_max_(output, index, input, dim=0):
     [torch.LongTensor of size 2x6]
     )
     """
-    arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
-    return scatter('max', dim, output, index, input, arg_output)
+    arg = gen_filled_tensor(index, output.size(), fill_value=-1)
+    return scatter('max', dim, output, index, input, arg)
 
 
 def scatter_max(index, input, dim=0, size=None, fill_value=0):
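For readers skimming the rename: the `arg` tensor filled with `-1` above records, per output slot, which input position produced the maximum, and the `-1` survives wherever no input element was scattered. A minimal pure-PyTorch sketch of these semantics for the 1-D case (an illustration with a hypothetical helper name, not the library's actual kernel):

```python
import torch

def scatter_max_reference(index, input, size):
    # Illustrative 1-D reference (dim=0): `out` holds the running maxima,
    # `arg` the input position that produced each maximum.
    # fill_value=-1 marks output slots that received no input element.
    out = torch.full((size,), float('-inf'))
    arg = torch.full((size,), -1, dtype=torch.long)
    for i in range(input.numel()):
        j = index[i].item()
        if input[i] > out[j]:
            out[j] = input[i]
            arg[j] = i
    return out, arg

out, arg = scatter_max_reference(torch.tensor([0, 0, 1]),
                                 torch.tensor([2.0, 5.0, 3.0]), size=3)
# out: tensor([5., 3., -inf]), arg: tensor([1, 2, -1])
```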
@@ -44,10 +44,10 @@ def scatter_mean_(output, index, input, dim=0):
     1.0000  4.0000  2.0000  0.0000  0.0000  0.0000
     [torch.FloatTensor of size 2x6]
     """
-    num_output = gen_filled_tensor(output, output.size(), fill_value=0)
-    scatter('mean', dim, output, index, input, num_output)
-    num_output[num_output == 0] = 1
-    output /= num_output
+    count = gen_filled_tensor(output, output.size(), fill_value=0)
+    scatter('mean', dim, output, index, input, count)
+    count[count == 0] = 1
+    output /= count
     return output
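The renamed `count` tensor makes the mean trick easier to read: the kernel scatter-adds the inputs while counting contributions per slot, then divides, with zero counts clamped to one so untouched slots keep their fill value. A rough pure-PyTorch equivalent of that logic (assumed semantics, 1-D case):

```python
import torch

index = torch.tensor([0, 0, 1])
input = torch.tensor([2.0, 4.0, 3.0])

output = torch.zeros(3).scatter_add_(0, index, input)
count = torch.zeros(3).scatter_add_(0, index, torch.ones_like(input))
count[count == 0] = 1  # avoid dividing untouched slots by zero
output /= count
# output: tensor([3., 3., 0.])
```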
@@ -50,8 +50,8 @@ def scatter_min_(output, index, input, dim=0):
     [torch.LongTensor of size 2x6]
     )
     """
-    arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
-    return scatter('min', dim, output, index, input, arg_output)
+    arg = gen_filled_tensor(index, output.size(), fill_value=-1)
+    return scatter('min', dim, output, index, input, arg)
 
 
 def scatter_min(index, input, dim=0, size=None, fill_value=0):
@@ -6,7 +6,7 @@ from torch.autograd import Function
 from .._ext import ffi
 
 
-def has_arg_output(name):
+def has_arg(name):
     return name in ['max', 'min']
@@ -36,7 +36,7 @@ def _scatter(name, dim, *data):
     cuda = 'cuda_' if data[0].is_cuda else ''
     func = getattr(ffi, 'scatter_{}_{}{}'.format(name, cuda, typename))
     func(dim, *data)
-    return (data[0], data[3]) if has_arg_output(name) else data[0]
+    return (data[0], data[3]) if has_arg(name) else data[0]
 
 
 def index_backward(dim, index, grad, arg):
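The `arg` tensor is exactly what `index_backward` needs: for max/min, only the winning input position per output slot should receive gradient. A hypothetical 1-D reference of that routing (the real `index_backward` dispatches to the ffi kernels above; the function name here is invented for illustration):

```python
import torch

def index_backward_reference(dim, index, grad, arg):
    # 1-D sketch (dim == 0): `index` has the input's shape, `grad` and
    # `arg` the output's shape. Each output slot sends its gradient back
    # to the one input position recorded in `arg`; -1 means no contributor.
    grad_input = torch.zeros(index.size(), dtype=grad.dtype)
    for j in range(arg.numel()):
        i = arg[j].item()
        if i >= 0:
            grad_input[i] = grad[j]
    return grad_input

grad_input = index_backward_reference(
    0, torch.tensor([0, 0, 1]), torch.tensor([1.0, 2.0]),
    torch.tensor([1, 2]))
# grad_input: tensor([0., 1., 2.]) -- only the argmax positions get gradient
```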
@@ -63,9 +63,9 @@ class _Scatter(Function):
         _scatter(self.name, self.dim, *data)
 
         # `scatter_min` and `scatter_max` additionally return the `argmax`
-        # respectively `argmin`. In addition, we need to save the
-        # `arg_output` for the backward pass.
-        if has_arg_output(self.name):
+        # respectively `argmin`. Therefore, we need to save the `arg` for the
+        # backward pass.
+        if has_arg(self.name):
             self.save_for_backward(data[1], data[3])
             return data[0], data[3]
         else:
@@ -80,13 +80,13 @@ class _Scatter(Function):
         # Different grad computation of `input` if `scatter_max` or
         # `scatter_min` was used.
-        if self.needs_input_grad[2] and not has_arg_output(self.name):
+        if self.needs_input_grad[2] and not has_arg(self.name):
             index, = self.saved_variables
             grad_input = data[0].gather(self.dim, index.data)
 
-        if self.needs_input_grad[2] and has_arg_output(self.name):
-            index, arg_grad = self.saved_variables
-            data = (index.data, data[0], arg_grad.data)
+        if self.needs_input_grad[2] and has_arg(self.name):
+            index, arg = self.saved_variables
+            data = (index.data, data[0], arg.data)
             grad_input = index_backward(self.dim, *data)
 
         # Return and fill with empty grads for non-differentiable passed
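For the reductions without an `arg` tensor, the first branch above is all there is: the input gradient is the output gradient gathered back through `index`. A small illustration of that `gather` call in plain PyTorch:

```python
import torch

index = torch.tensor([0, 0, 1])           # two inputs mapped to slot 0
grad_output = torch.tensor([10.0, 20.0])  # gradient w.r.t. the output
grad_input = grad_output.gather(0, index)
# grad_input: tensor([10., 10., 20.]) -- each input inherits its slot's grad
```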