Commit 573443b6 authored by rusty1s's avatar rusty1s
Browse files

rename

parent cf0f8920
...@@ -38,11 +38,11 @@ def _scatter(name, dim, *data): ...@@ -38,11 +38,11 @@ def _scatter(name, dim, *data):
return (data[0], data[3]) if has_arg_output(name) else data[0] return (data[0], data[3]) if has_arg_output(name) else data[0]
def index_backward(dim, index, grad, grad_arg): def index_backward(dim, index, grad, arg_grad):
typename = type(grad).__name__.replace('Tensor', '') typename = type(grad).__name__.replace('Tensor', '')
func = getattr(ffi, 'index_backward_{}'.format(typename)) func = getattr(ffi, 'index_backward_{}'.format(typename))
output = grad.new(index.size()).fill_(0) output = grad.new(index.size()).fill_(0)
func(dim, output, index, grad, grad_arg) func(dim, output, index, grad, arg_grad)
return output return output
...@@ -83,8 +83,8 @@ class _Scatter(Function): ...@@ -83,8 +83,8 @@ class _Scatter(Function):
grad_input = data[0].gather(self.dim, index.data) grad_input = data[0].gather(self.dim, index.data)
if self.needs_input_grad[2] and has_arg_output(self.name): if self.needs_input_grad[2] and has_arg_output(self.name):
index, grad_arg = self.saved_variables index, arg_grad = self.saved_variables
data = (index.data, data[0], grad_arg.data) data = (index.data, data[0], arg_grad.data)
grad_input = index_backward(self.dim, *data) grad_input = index_backward(self.dim, *data)
# Return and fill with empty grads for none-differentiable passed # Return and fill with empty grads for none-differentiable passed
......
...@@ -14,13 +14,13 @@ void scatter_div_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTens ...@@ -14,13 +14,13 @@ void scatter_div_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTens
void scatter_div_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input); void scatter_div_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input);
void scatter_div_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input); void scatter_div_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input);
void scatter_mean_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaTensor *output_count); void scatter_mean_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaTensor *num_output);
void scatter_mean_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaDoubleTensor *output_count); void scatter_mean_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaDoubleTensor *num_output);
void scatter_mean_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaByteTensor *output_count); void scatter_mean_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaByteTensor *num_output);
void scatter_mean_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaCharTensor *output_count); void scatter_mean_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaCharTensor *num_output);
void scatter_mean_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaShortTensor *output_count); void scatter_mean_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaShortTensor *num_output);
void scatter_mean_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaIntTensor *output_count); void scatter_mean_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaIntTensor *num_output);
void scatter_mean_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *output_count); void scatter_mean_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *num_output);
void scatter_max_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaLongTensor *arg_output); void scatter_max_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaLongTensor *arg_output);
void scatter_max_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaLongTensor *arg_output); void scatter_max_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaLongTensor *arg_output);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment