Commit cf0f8920 authored by rusty1s's avatar rusty1s
Browse files

rename

parent aeb47792
...@@ -14,15 +14,15 @@ def test_scatter_mean(str): ...@@ -14,15 +14,15 @@ def test_scatter_mean(str):
index = torch.LongTensor(index) index = torch.LongTensor(index)
output = input.new(2, 6).fill_(0) output = input.new(2, 6).fill_(0)
expected_output = [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]] expected_output = [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]
expected_output_arg = [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]] expected_arg_output = [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
_, output_arg = scatter_max_(output, index, input, dim=1) _, arg_output = scatter_max_(output, index, input, dim=1)
assert output.tolist() == expected_output assert output.tolist() == expected_output
assert output_arg.tolist() == expected_output_arg assert arg_output.tolist() == expected_arg_output
output, output_arg = scatter_max(index, input, dim=1) output, arg_output = scatter_max(index, input, dim=1)
assert output.tolist() == expected_output assert output.tolist() == expected_output
assert output_arg.tolist() == expected_output_arg assert arg_output.tolist() == expected_arg_output
output = Variable(output).fill_(0) output = Variable(output).fill_(0)
index = Variable(index) index = Variable(index)
......
...@@ -51,10 +51,10 @@ def scatter_div(index, input, dim=0, max_index=None, fill_value=1): ...@@ -51,10 +51,10 @@ def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
def scatter_mean_(output, index, input, dim=0): def scatter_mean_(output, index, input, dim=0):
"""If multiple indices reference the same location, their """If multiple indices reference the same location, their
contributions average.""" contributions average."""
output_count = gen_filled_tensor(output, output.size(), fill_value=0) num_output = gen_filled_tensor(output, output.size(), fill_value=0)
scatter('mean', dim, output, index, input, output_count) scatter('mean', dim, output, index, input, num_output)
output_count[output_count == 0] = 1 num_output[num_output == 0] = 1
output /= output_count output /= num_output
return output return output
...@@ -66,8 +66,8 @@ def scatter_mean(index, input, dim=0, max_index=None, fill_value=0): ...@@ -66,8 +66,8 @@ def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
def scatter_max_(output, index, input, dim=0): def scatter_max_(output, index, input, dim=0):
"""If multiple indices reference the same location, the maximal """If multiple indices reference the same location, the maximal
contribution gets taken.""" contribution gets taken."""
output_arg = gen_filled_tensor(index, output.size(), fill_value=-1) arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
return scatter('max', dim, output, index, input, output_arg) return scatter('max', dim, output, index, input, arg_output)
def scatter_max(index, input, dim=0, max_index=None, fill_value=0): def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
...@@ -78,8 +78,8 @@ def scatter_max(index, input, dim=0, max_index=None, fill_value=0): ...@@ -78,8 +78,8 @@ def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
def scatter_min_(output, index, input, dim=0): def scatter_min_(output, index, input, dim=0):
"""If multiple indices reference the same location, the minimal """If multiple indices reference the same location, the minimal
contribution gets taken.""" contribution gets taken."""
output_arg = gen_filled_tensor(index, output.size(), fill_value=-1) arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
return scatter('min', dim, output, index, input, output_arg) return scatter('min', dim, output, index, input, arg_output)
def scatter_min(index, input, dim=0, max_index=None, fill_value=0): def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
......
...@@ -6,7 +6,7 @@ from torch.autograd import Function ...@@ -6,7 +6,7 @@ from torch.autograd import Function
from .._ext import ffi from .._ext import ffi
def has_output_arg(name): def has_arg_output(name):
return name in ['max', 'min'] return name in ['max', 'min']
...@@ -35,7 +35,7 @@ def _scatter(name, dim, *data): ...@@ -35,7 +35,7 @@ def _scatter(name, dim, *data):
typename = type(data[0]).__name__.replace('Tensor', '') typename = type(data[0]).__name__.replace('Tensor', '')
func = getattr(ffi, 'scatter_{}_{}'.format(name, typename)) func = getattr(ffi, 'scatter_{}_{}'.format(name, typename))
func(dim, *data) func(dim, *data)
return (data[0], data[3]) if has_output_arg(name) else data[0] return (data[0], data[3]) if has_arg_output(name) else data[0]
def index_backward(dim, index, grad, grad_arg): def index_backward(dim, index, grad, grad_arg):
...@@ -62,8 +62,8 @@ class _Scatter(Function): ...@@ -62,8 +62,8 @@ class _Scatter(Function):
# `scatter_min` and `scatter_max` additionally return the `argmax` # `scatter_min` and `scatter_max` additionally return the `argmax`
# respectively `argmin`. In addition, we need to save the # respectively `argmin`. In addition, we need to save the
# `output_arg` for the backward pass. # `arg_output` for the backward pass.
if has_output_arg(self.name): if has_arg_output(self.name):
self.save_for_backward(data[1], data[3]) self.save_for_backward(data[1], data[3])
return data[0], data[3] return data[0], data[3]
else: else:
...@@ -78,11 +78,11 @@ class _Scatter(Function): ...@@ -78,11 +78,11 @@ class _Scatter(Function):
# Different grad computation of `input` if `scatter_max` or # Different grad computation of `input` if `scatter_max` or
# `scatter_min` was used. # `scatter_min` was used.
if self.needs_input_grad[2] and not has_output_arg(self.name): if self.needs_input_grad[2] and not has_arg_output(self.name):
index, = self.saved_variables index, = self.saved_variables
grad_input = data[0].gather(self.dim, index.data) grad_input = data[0].gather(self.dim, index.data)
if self.needs_input_grad[2] and has_output_arg(self.name): if self.needs_input_grad[2] and has_arg_output(self.name):
index, grad_arg = self.saved_variables index, grad_arg = self.saved_variables
data = (index.data, data[0], grad_arg.data) data = (index.data, data[0], grad_arg.data)
grad_input = index_backward(self.dim, *data) grad_input = index_backward(self.dim, *data)
......
...@@ -14,34 +14,34 @@ void scatter_div_Short (int dim, THShortTensor *output, THLongTensor *index, TH ...@@ -14,34 +14,34 @@ void scatter_div_Short (int dim, THShortTensor *output, THLongTensor *index, TH
void scatter_div_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input); void scatter_div_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input);
void scatter_div_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input); void scatter_div_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input);
void scatter_mean_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THFloatTensor *output_count); void scatter_mean_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THFloatTensor *num_output);
void scatter_mean_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THDoubleTensor *output_count); void scatter_mean_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THDoubleTensor *num_output);
void scatter_mean_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THByteTensor *output_count); void scatter_mean_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THByteTensor *num_output);
void scatter_mean_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THCharTensor *output_count); void scatter_mean_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THCharTensor *num_output);
void scatter_mean_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THShortTensor *output_count); void scatter_mean_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THShortTensor *num_output);
void scatter_mean_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THIntTensor *output_count); void scatter_mean_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THIntTensor *num_output);
void scatter_mean_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_count); void scatter_mean_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *num_output);
void scatter_max_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *output_arg); void scatter_max_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *arg_output);
void scatter_max_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *output_arg); void scatter_max_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *arg_output);
void scatter_max_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *output_arg); void scatter_max_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *arg_output);
void scatter_max_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *output_arg); void scatter_max_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *arg_output);
void scatter_max_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *output_arg); void scatter_max_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *arg_output);
void scatter_max_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *output_arg); void scatter_max_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *arg_output);
void scatter_max_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_arg); void scatter_max_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *arg_output);
void scatter_min_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *output_arg); void scatter_min_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *arg_output);
void scatter_min_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *output_arg); void scatter_min_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *arg_output);
void scatter_min_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *output_arg); void scatter_min_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *arg_output);
void scatter_min_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *output_arg); void scatter_min_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *arg_output);
void scatter_min_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *output_arg); void scatter_min_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *arg_output);
void scatter_min_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *output_arg); void scatter_min_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *arg_output);
void scatter_min_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_arg); void scatter_min_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *arg_output);
void index_backward_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *grad, THLongTensor *grad_arg); void index_backward_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *grad, THLongTensor *arg_grad);
void index_backward_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *grad, THLongTensor *grad_arg); void index_backward_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *grad, THLongTensor *arg_grad);
void index_backward_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *grad, THLongTensor *grad_arg); void index_backward_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *grad, THLongTensor *arg_grad);
void index_backward_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *grad, THLongTensor *grad_arg); void index_backward_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *grad, THLongTensor *arg_grad);
void index_backward_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *grad, THLongTensor *grad_arg); void index_backward_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *grad, THLongTensor *arg_grad);
void index_backward_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *grad, THLongTensor *grad_arg); void index_backward_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *grad, THLongTensor *arg_grad);
void index_backward_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *grad, THLongTensor *grad_arg); void index_backward_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *grad, THLongTensor *arg_grad);
void scatter_mul_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input);
void scatter_mul_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input);
void scatter_mul_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input);
void scatter_mul_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input);
void scatter_mul_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input);
void scatter_mul_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input);
void scatter_mul_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input);
void scatter_div_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input);
void scatter_div_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input);
void scatter_div_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input);
void scatter_div_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input);
void scatter_div_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input);
void scatter_div_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input);
void scatter_div_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input);
void scatter_mean_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaTensor *output_count);
void scatter_mean_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaDoubleTensor *output_count);
void scatter_mean_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaByteTensor *output_count);
void scatter_mean_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaCharTensor *output_count);
void scatter_mean_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaShortTensor *output_count);
void scatter_mean_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaIntTensor *output_count);
void scatter_mean_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *output_count);
void scatter_max_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaLongTensor *arg_output);
void scatter_max_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaLongTensor *arg_output);
void scatter_max_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaLongTensor *arg_output);
void scatter_max_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaLongTensor *arg_output);
void scatter_max_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaLongTensor *arg_output);
void scatter_max_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaLongTensor *arg_output);
void scatter_max_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *arg_output);
void scatter_min_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaLongTensor *arg_output);
void scatter_min_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaLongTensor *arg_output);
void scatter_min_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaLongTensor *arg_output);
void scatter_min_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaLongTensor *arg_output);
void scatter_min_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaLongTensor *arg_output);
void scatter_min_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaLongTensor *arg_output);
void scatter_min_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *arg_output);
void index_backward_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *grad, THCudaLongTensor *arg_grad);
void index_backward_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *grad, THCudaLongTensor *arg_grad);
void index_backward_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *grad, THCudaLongTensor *arg_grad);
void index_backward_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *grad, THCudaLongTensor *arg_grad);
void index_backward_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *grad, THCudaLongTensor *arg_grad);
void index_backward_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *grad, THCudaLongTensor *arg_grad);
void index_backward_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *grad, THCudaLongTensor *arg_grad);
...@@ -18,41 +18,41 @@ void scatter_(div)(int dim, THTensor *output, THLongTensor *index, THTensor *inp ...@@ -18,41 +18,41 @@ void scatter_(div)(int dim, THTensor *output, THLongTensor *index, THTensor *inp
}) })
} }
void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THTensor *output_count) { void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THTensor *num_output) {
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, real, output_count, dim, TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, real, num_output, dim,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) { for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter); assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[index_data[i]] += input_data[i]; output_data[index_data[i]] += input_data[i];
output_count_data[index_data[i]]++; num_output_data[index_data[i]]++;
}) })
} }
void scatter_(max)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *output_arg) { void scatter_(max)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *arg_output) {
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, output_arg, dim, TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, arg_output, dim,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) { for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter); assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
if (input_data[i] >= output_data[index_data[i]]) { if (input_data[i] >= output_data[index_data[i]]) {
output_data[index_data[i]] = input_data[i]; output_data[index_data[i]] = input_data[i];
output_arg_data[index_data[i]] = i; arg_output_data[index_data[i]] = i;
} }
}) })
} }
void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *output_arg) { void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *arg_output) {
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, output_arg, dim, TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, arg_output, dim,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) { for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter); assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
if (input_data[i] <= output_data[index_data[i]]) { if (input_data[i] <= output_data[index_data[i]]) {
output_data[index_data[i]] = input_data[i]; output_data[index_data[i]] = input_data[i];
output_arg_data[index_data[i]] = i; arg_output_data[index_data[i]] = i;
} }
}) })
} }
void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *grad_arg) { void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *arg_grad) {
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, grad_arg, dim, TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, arg_grad, dim,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) { for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
if (grad_arg_data[index_data[i]] == i) output_data[i] = grad_data[index_data[i]]; if (arg_grad_data[index_data[i]] == i) output_data[i] = grad_data[index_data[i]];
}) })
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment