Commit 919572e3 authored by rusty1s's avatar rusty1s
Browse files

renames

parent f655f536
...@@ -14,15 +14,15 @@ def test_scatter_mean(str): ...@@ -14,15 +14,15 @@ def test_scatter_mean(str):
index = torch.LongTensor(index) index = torch.LongTensor(index)
output = input.new(2, 6).fill_(0) output = input.new(2, 6).fill_(0)
expected_output = [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]] expected_output = [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]
expected_output_index = [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]] expected_output_arg = [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
_, output_index = scatter_max_(output, index, input, dim=1) _, output_arg = scatter_max_(output, index, input, dim=1)
assert output.tolist() == expected_output assert output.tolist() == expected_output
assert output_index.tolist() == expected_output_index assert output_arg.tolist() == expected_output_arg
output, output_index = scatter_max(index, input, dim=1) output, output_arg = scatter_max(index, input, dim=1)
assert output.tolist() == expected_output assert output.tolist() == expected_output
assert output_index.tolist() == expected_output_index assert output_arg.tolist() == expected_output_arg
output = Variable(output).fill_(0) output = Variable(output).fill_(0)
index = Variable(index) index = Variable(index)
......
...@@ -52,8 +52,8 @@ def scatter_mean(index, input, dim=0, max_index=None, fill_value=0): ...@@ -52,8 +52,8 @@ def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
def scatter_max_(output, index, input, dim=0): def scatter_max_(output, index, input, dim=0):
output_index = gen_filled_tensor(index, output.size(), fill_value=-1) output_arg = gen_filled_tensor(index, output.size(), fill_value=-1)
return scatter('max', dim, output, index, input, output_index) return scatter('max', dim, output, index, input, output_arg)
def scatter_max(index, input, dim=0, max_index=None, fill_value=0): def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
...@@ -62,8 +62,8 @@ def scatter_max(index, input, dim=0, max_index=None, fill_value=0): ...@@ -62,8 +62,8 @@ def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
def scatter_min_(output, index, input, dim=0): def scatter_min_(output, index, input, dim=0):
output_index = gen_filled_tensor(index, output.size(), fill_value=-1) output_arg = gen_filled_tensor(index, output.size(), fill_value=-1)
return scatter('min', dim, output, index, input, output_index) return scatter('min', dim, output, index, input, output_arg)
def scatter_min(index, input, dim=0, max_index=None, fill_value=0): def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
......
...@@ -6,7 +6,7 @@ from torch.autograd import Function ...@@ -6,7 +6,7 @@ from torch.autograd import Function
from .._ext import ffi from .._ext import ffi
def has_output_index(name): def has_output_arg(name):
return name in ['max', 'min'] return name in ['max', 'min']
...@@ -35,14 +35,14 @@ def _scatter(name, dim, *data): ...@@ -35,14 +35,14 @@ def _scatter(name, dim, *data):
typename = type(data[0]).__name__.replace('Tensor', '') typename = type(data[0]).__name__.replace('Tensor', '')
func = getattr(ffi, 'scatter_{}_{}'.format(name, typename)) func = getattr(ffi, 'scatter_{}_{}'.format(name, typename))
func(dim, *data) func(dim, *data)
return (data[0], data[3]) if has_output_index(name) else data[0] return (data[0], data[3]) if has_output_arg(name) else data[0]
def index_backward(dim, index, grad, grad_index): def index_backward(dim, index, grad, grad_arg):
typename = type(grad).__name__.replace('Tensor', '') typename = type(grad).__name__.replace('Tensor', '')
func = getattr(ffi, 'index_backward_{}'.format(typename)) func = getattr(ffi, 'index_backward_{}'.format(typename))
output = grad.new(index.size()).fill_(0) output = grad.new(index.size()).fill_(0)
func(dim, output, index, grad, grad_index) func(dim, output, index, grad, grad_arg)
return output return output
...@@ -62,8 +62,8 @@ class _Scatter(Function): ...@@ -62,8 +62,8 @@ class _Scatter(Function):
# `scatter_min` and `scatter_max` additionally return the `argmax` # `scatter_min` and `scatter_max` additionally return the `argmax`
# respectively `argmin`. In addition, we need to save the # respectively `argmin`. In addition, we need to save the
# `output_index` for the backward pass. # `output_arg` for the backward pass.
if has_output_index(self.name): if has_output_arg(self.name):
self.save_for_backward(data[1], data[3]) self.save_for_backward(data[1], data[3])
return data[0], data[3] return data[0], data[3]
else: else:
...@@ -78,13 +78,13 @@ class _Scatter(Function): ...@@ -78,13 +78,13 @@ class _Scatter(Function):
# Different grad computation of `input` if `scatter_max` or # Different grad computation of `input` if `scatter_max` or
# `scatter_min` was used. # `scatter_min` was used.
if self.needs_input_grad[2] and not has_output_index(self.name): if self.needs_input_grad[2] and not has_output_arg(self.name):
index, = self.saved_variables index, = self.saved_variables
grad_input = data[0].gather(self.dim, index.data) grad_input = data[0].gather(self.dim, index.data)
if self.needs_input_grad[2] and has_output_index(self.name): if self.needs_input_grad[2] and has_output_arg(self.name):
index, grad_index = self.saved_variables index, grad_arg = self.saved_variables
data = (index.data, data[0], grad_index.data) data = (index.data, data[0], grad_arg.data)
grad_input = index_backward(self.dim, *data) grad_input = index_backward(self.dim, *data)
# Return and fill with empty grads for none-differentiable passed # Return and fill with empty grads for none-differentiable passed
......
...@@ -38,26 +38,26 @@ void scatter_mean_Short (int dim, THShortTensor *output, THLongTensor *index, T ...@@ -38,26 +38,26 @@ void scatter_mean_Short (int dim, THShortTensor *output, THLongTensor *index, T
void scatter_mean_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THIntTensor *output_count); void scatter_mean_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THIntTensor *output_count);
void scatter_mean_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_count); void scatter_mean_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_count);
void scatter_max_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *output_index); void scatter_max_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *output_arg);
void scatter_max_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *output_index); void scatter_max_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *output_arg);
void scatter_max_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *output_index); void scatter_max_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *output_arg);
void scatter_max_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *output_index); void scatter_max_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *output_arg);
void scatter_max_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *output_index); void scatter_max_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *output_arg);
void scatter_max_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *output_index); void scatter_max_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *output_arg);
void scatter_max_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_index); void scatter_max_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_arg);
void scatter_min_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *output_index); void scatter_min_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *output_arg);
void scatter_min_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *output_index); void scatter_min_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *output_arg);
void scatter_min_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *output_index); void scatter_min_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *output_arg);
void scatter_min_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *output_index); void scatter_min_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *output_arg);
void scatter_min_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *output_index); void scatter_min_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *output_arg);
void scatter_min_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *output_index); void scatter_min_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *output_arg);
void scatter_min_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_index); void scatter_min_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_arg);
void index_backward_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *grad, THLongTensor *grad_index); void index_backward_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *grad, THLongTensor *grad_arg);
void index_backward_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *grad, THLongTensor *grad_index); void index_backward_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *grad, THLongTensor *grad_arg);
void index_backward_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *grad, THLongTensor *grad_index); void index_backward_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *grad, THLongTensor *grad_arg);
void index_backward_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *grad, THLongTensor *grad_index); void index_backward_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *grad, THLongTensor *grad_arg);
void index_backward_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *grad, THLongTensor *grad_index); void index_backward_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *grad, THLongTensor *grad_arg);
void index_backward_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *grad, THLongTensor *grad_index); void index_backward_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *grad, THLongTensor *grad_arg);
void index_backward_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *grad, THLongTensor *grad_index); void index_backward_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *grad, THLongTensor *grad_arg);
...@@ -43,32 +43,32 @@ void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *in ...@@ -43,32 +43,32 @@ void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *in
}) })
} }
void scatter_(max)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *output_index) { void scatter_(max)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *output_arg) {
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, output_index, dim, TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, output_arg, dim,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) { for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter); assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
if (input_data[i] >= output_data[index_data[i]]) { if (input_data[i] >= output_data[index_data[i]]) {
output_data[index_data[i]] = input_data[i]; output_data[index_data[i]] = input_data[i];
output_index_data[index_data[i]] = i; output_arg_data[index_data[i]] = i;
} }
}) })
} }
void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *output_index) { void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *output_arg) {
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, output_index, dim, TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, output_arg, dim,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) { for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter); assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
if (input_data[i] <= output_data[index_data[i]]) { if (input_data[i] <= output_data[index_data[i]]) {
output_data[index_data[i]] = input_data[i]; output_data[index_data[i]] = input_data[i];
output_index_data[index_data[i]] = i; output_arg_data[index_data[i]] = i;
} }
}) })
} }
void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *grad_index) { void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *grad_arg) {
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, grad_index, dim, TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, grad_arg, dim,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) { for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
if (grad_index_data[index_data[i]] == i) output_data[i] = grad_data[index_data[i]]; if (grad_arg_data[index_data[i]] == i) output_data[i] = grad_data[index_data[i]];
}) })
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment