"git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "90e63784f3d41909a713cf68ac381ee5fd94ff9a"
Commit aeb47792 authored by rusty1s's avatar rusty1s
Browse files

scatter add uses pytorch impl

parent 1270e840
import torch
from .scatter import scatter from .scatter import scatter
from .utils import gen_filled_tensor, gen_output from .utils import gen_filled_tensor, gen_output
...@@ -5,7 +7,7 @@ from .utils import gen_filled_tensor, gen_output ...@@ -5,7 +7,7 @@ from .utils import gen_filled_tensor, gen_output
def scatter_add_(output, index, input, dim=0): def scatter_add_(output, index, input, dim=0):
"""If multiple indices reference the same location, their contributions """If multiple indices reference the same location, their contributions
add.""" add."""
return scatter('add', dim, output, index, input) return output.scatter_add_(dim, index, input)
def scatter_add(index, input, dim=0, max_index=None, fill_value=0): def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
...@@ -16,7 +18,7 @@ def scatter_add(index, input, dim=0, max_index=None, fill_value=0): ...@@ -16,7 +18,7 @@ def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
def scatter_sub_(output, index, input, dim=0): def scatter_sub_(output, index, input, dim=0):
"""If multiple indices reference the same location, their negated """If multiple indices reference the same location, their negated
contributions add.""" contributions add."""
return scatter('sub', dim, output, index, input) return output.scatter_add_(dim, index, -input)
def scatter_sub(index, input, dim=0, max_index=None, fill_value=0): def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
......
void scatter_add_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input);
void scatter_add_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input);
void scatter_add_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input);
void scatter_add_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input);
void scatter_add_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input);
void scatter_add_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input);
void scatter_add_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input);
void scatter_sub_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input);
void scatter_sub_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input);
void scatter_sub_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input);
void scatter_sub_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input);
void scatter_sub_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input);
void scatter_sub_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input);
void scatter_sub_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input);
void scatter_mul_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input); void scatter_mul_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input);
void scatter_mul_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input); void scatter_mul_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input);
void scatter_mul_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input); void scatter_mul_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input);
......
...@@ -2,22 +2,6 @@ ...@@ -2,22 +2,6 @@
#define TH_GENERIC_FILE "generic/cpu.c" #define TH_GENERIC_FILE "generic/cpu.c"
#else #else
void scatter_(add)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[index_data[i]] += input_data[i];
})
}
void scatter_(sub)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[index_data[i]] -= input_data[i];
})
}
void scatter_(mul)(int dim, THTensor *output, THLongTensor *index, THTensor *input) { void scatter_(mul)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) { for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment