Commit 573443b6 authored by rusty1s's avatar rusty1s
Browse files

rename

parent cf0f8920
......@@ -38,11 +38,11 @@ def _scatter(name, dim, *data):
return (data[0], data[3]) if has_arg_output(name) else data[0]
def index_backward(dim, index, grad, grad_arg):
def index_backward(dim, index, grad, arg_grad):
typename = type(grad).__name__.replace('Tensor', '')
func = getattr(ffi, 'index_backward_{}'.format(typename))
output = grad.new(index.size()).fill_(0)
func(dim, output, index, grad, grad_arg)
func(dim, output, index, grad, arg_grad)
return output
......@@ -83,8 +83,8 @@ class _Scatter(Function):
grad_input = data[0].gather(self.dim, index.data)
if self.needs_input_grad[2] and has_arg_output(self.name):
index, grad_arg = self.saved_variables
data = (index.data, data[0], grad_arg.data)
index, arg_grad = self.saved_variables
data = (index.data, data[0], arg_grad.data)
grad_input = index_backward(self.dim, *data)
# Return and fill with empty grads for none-differentiable passed
......
......@@ -14,13 +14,13 @@ void scatter_div_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTens
void scatter_div_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input);
void scatter_div_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input);
void scatter_mean_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaTensor *output_count);
void scatter_mean_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaDoubleTensor *output_count);
void scatter_mean_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaByteTensor *output_count);
void scatter_mean_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaCharTensor *output_count);
void scatter_mean_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaShortTensor *output_count);
void scatter_mean_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaIntTensor *output_count);
void scatter_mean_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *output_count);
void scatter_mean_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaTensor *num_output);
void scatter_mean_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaDoubleTensor *num_output);
void scatter_mean_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaByteTensor *num_output);
void scatter_mean_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaCharTensor *num_output);
void scatter_mean_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaShortTensor *num_output);
void scatter_mean_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaIntTensor *num_output);
void scatter_mean_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *num_output);
void scatter_max_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaLongTensor *arg_output);
void scatter_max_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaLongTensor *arg_output);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment