Commit aeb47792 authored by rusty1s's avatar rusty1s
Browse files

scatter add uses pytorch impl

parent 1270e840
import torch
from .scatter import scatter
from .utils import gen_filled_tensor, gen_output
......@@ -5,7 +7,7 @@ from .utils import gen_filled_tensor, gen_output
def scatter_add_(output, index, input, dim=0):
"""If multiple indices reference the same location, their contributions
add."""
return scatter('add', dim, output, index, input)
return output.scatter_add_(dim, index, input)
def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
......@@ -16,7 +18,7 @@ def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
def scatter_sub_(output, index, input, dim=0):
"""If multiple indices reference the same location, their negated
contributions add."""
return scatter('sub', dim, output, index, input)
return output.scatter_add_(dim, index, -input)
def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
......
void scatter_add_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input);
void scatter_add_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input);
void scatter_add_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input);
void scatter_add_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input);
void scatter_add_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input);
void scatter_add_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input);
void scatter_add_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input);
void scatter_sub_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input);
void scatter_sub_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input);
void scatter_sub_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input);
void scatter_sub_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input);
void scatter_sub_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input);
void scatter_sub_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input);
void scatter_sub_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input);
void scatter_mul_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input);
void scatter_mul_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input);
void scatter_mul_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input);
......
......@@ -2,22 +2,6 @@
#define TH_GENERIC_FILE "generic/cpu.c"
#else
void scatter_(add)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[index_data[i]] += input_data[i];
})
}
void scatter_(sub)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[index_data[i]] -= input_data[i];
})
}
void scatter_(mul)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment