Commit 919572e3 authored by rusty1s's avatar rusty1s
Browse files

renames

parent f655f536
......@@ -14,15 +14,15 @@ def test_scatter_mean(str):
index = torch.LongTensor(index)
output = input.new(2, 6).fill_(0)
expected_output = [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]
expected_output_index = [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
expected_output_arg = [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
_, output_index = scatter_max_(output, index, input, dim=1)
_, output_arg = scatter_max_(output, index, input, dim=1)
assert output.tolist() == expected_output
assert output_index.tolist() == expected_output_index
assert output_arg.tolist() == expected_output_arg
output, output_index = scatter_max(index, input, dim=1)
output, output_arg = scatter_max(index, input, dim=1)
assert output.tolist() == expected_output
assert output_index.tolist() == expected_output_index
assert output_arg.tolist() == expected_output_arg
output = Variable(output).fill_(0)
index = Variable(index)
......
......@@ -52,8 +52,8 @@ def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
def scatter_max_(output, index, input, dim=0):
output_index = gen_filled_tensor(index, output.size(), fill_value=-1)
return scatter('max', dim, output, index, input, output_index)
output_arg = gen_filled_tensor(index, output.size(), fill_value=-1)
return scatter('max', dim, output, index, input, output_arg)
def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
......@@ -62,8 +62,8 @@ def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
def scatter_min_(output, index, input, dim=0):
output_index = gen_filled_tensor(index, output.size(), fill_value=-1)
return scatter('min', dim, output, index, input, output_index)
output_arg = gen_filled_tensor(index, output.size(), fill_value=-1)
return scatter('min', dim, output, index, input, output_arg)
def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
......
......@@ -6,7 +6,7 @@ from torch.autograd import Function
from .._ext import ffi
def has_output_index(name):
def has_output_arg(name):
return name in ['max', 'min']
......@@ -35,14 +35,14 @@ def _scatter(name, dim, *data):
typename = type(data[0]).__name__.replace('Tensor', '')
func = getattr(ffi, 'scatter_{}_{}'.format(name, typename))
func(dim, *data)
return (data[0], data[3]) if has_output_index(name) else data[0]
return (data[0], data[3]) if has_output_arg(name) else data[0]
def index_backward(dim, index, grad, grad_index):
def index_backward(dim, index, grad, grad_arg):
typename = type(grad).__name__.replace('Tensor', '')
func = getattr(ffi, 'index_backward_{}'.format(typename))
output = grad.new(index.size()).fill_(0)
func(dim, output, index, grad, grad_index)
func(dim, output, index, grad, grad_arg)
return output
......@@ -62,8 +62,8 @@ class _Scatter(Function):
# `scatter_min` and `scatter_max` additionally return the `argmax`
# respectively `argmin`. In addition, we need to save the
# `output_index` for the backward pass.
if has_output_index(self.name):
# `output_arg` for the backward pass.
if has_output_arg(self.name):
self.save_for_backward(data[1], data[3])
return data[0], data[3]
else:
......@@ -78,13 +78,13 @@ class _Scatter(Function):
# Different grad computation of `input` if `scatter_max` or
# `scatter_min` was used.
if self.needs_input_grad[2] and not has_output_index(self.name):
if self.needs_input_grad[2] and not has_output_arg(self.name):
index, = self.saved_variables
grad_input = data[0].gather(self.dim, index.data)
if self.needs_input_grad[2] and has_output_index(self.name):
index, grad_index = self.saved_variables
data = (index.data, data[0], grad_index.data)
if self.needs_input_grad[2] and has_output_arg(self.name):
index, grad_arg = self.saved_variables
data = (index.data, data[0], grad_arg.data)
grad_input = index_backward(self.dim, *data)
# Return and fill with empty grads for none-differentiable passed
......
......@@ -38,26 +38,26 @@ void scatter_mean_Short (int dim, THShortTensor *output, THLongTensor *index, T
void scatter_mean_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THIntTensor *output_count);
void scatter_mean_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_count);
void scatter_max_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *output_index);
void scatter_max_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *output_index);
void scatter_max_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *output_index);
void scatter_max_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *output_index);
void scatter_max_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *output_index);
void scatter_max_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *output_index);
void scatter_max_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_index);
void scatter_max_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *output_arg);
void scatter_max_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *output_arg);
void scatter_max_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *output_arg);
void scatter_max_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *output_arg);
void scatter_max_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *output_arg);
void scatter_max_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *output_arg);
void scatter_max_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_arg);
void scatter_min_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *output_index);
void scatter_min_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *output_index);
void scatter_min_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *output_index);
void scatter_min_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *output_index);
void scatter_min_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *output_index);
void scatter_min_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *output_index);
void scatter_min_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_index);
void scatter_min_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *output_arg);
void scatter_min_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *output_arg);
void scatter_min_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *output_arg);
void scatter_min_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *output_arg);
void scatter_min_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *output_arg);
void scatter_min_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *output_arg);
void scatter_min_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_arg);
void index_backward_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *grad, THLongTensor *grad_index);
void index_backward_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *grad, THLongTensor *grad_index);
void index_backward_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *grad, THLongTensor *grad_index);
void index_backward_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *grad, THLongTensor *grad_index);
void index_backward_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *grad, THLongTensor *grad_index);
void index_backward_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *grad, THLongTensor *grad_index);
void index_backward_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *grad, THLongTensor *grad_index);
void index_backward_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *grad, THLongTensor *grad_arg);
void index_backward_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *grad, THLongTensor *grad_arg);
void index_backward_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *grad, THLongTensor *grad_arg);
void index_backward_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *grad, THLongTensor *grad_arg);
void index_backward_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *grad, THLongTensor *grad_arg);
void index_backward_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *grad, THLongTensor *grad_arg);
void index_backward_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *grad, THLongTensor *grad_arg);
......@@ -43,32 +43,32 @@ void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *in
})
}
void scatter_(max)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *output_index) {
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, output_index, dim,
void scatter_(max)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *output_arg) {
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, output_arg, dim,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
if (input_data[i] >= output_data[index_data[i]]) {
output_data[index_data[i]] = input_data[i];
output_index_data[index_data[i]] = i;
output_arg_data[index_data[i]] = i;
}
})
}
void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *output_index) {
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, output_index, dim,
void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *output_arg) {
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, output_arg, dim,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
if (input_data[i] <= output_data[index_data[i]]) {
output_data[index_data[i]] = input_data[i];
output_index_data[index_data[i]] = i;
output_arg_data[index_data[i]] = i;
}
})
}
void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *grad_index) {
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, grad_index, dim,
void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *grad_arg) {
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, grad_arg, dim,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
if (grad_index_data[index_data[i]] == i) output_data[i] = grad_data[index_data[i]];
if (grad_arg_data[index_data[i]] == i) output_data[i] = grad_data[index_data[i]];
})
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment