Commit 8ef7174b authored by rusty1s's avatar rusty1s
Browse files

bugfix

parent 05665f46
......@@ -8,5 +8,3 @@ inline void assertIndexInBoundaries(int idx, int size, int64_t *free) {
#include "generic/cpu.c"
#include "THGenerateAllTypes.h"
#include "generic/cpu.c"
#include "THGenerateHalfType.h"
void scatter_add_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_add_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
/* void scatter_add_Half (THHalfTensor *output, THLongTensor *index, THHalfTensor *input, int dim); */
void scatter_add_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_add_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_add_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
......@@ -9,7 +8,6 @@ void scatter_add_Long (THLongTensor *output, THLongTensor *index, THLongTenso
void scatter_sub_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_sub_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
/* void scatter_sub_Half (THHalfTensor *output, THLongTensor *index, THHalfTensor *input, int dim); */
void scatter_sub_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_sub_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_sub_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
......@@ -18,7 +16,6 @@ void scatter_sub_Long (THLongTensor *output, THLongTensor *index, THLongTenso
void scatter_mul_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_mul_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
/* void scatter_mul_Half (THHalfTensor *output, THLongTensor *index, THHalfTensor *input, int dim); */
void scatter_mul_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_mul_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_mul_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
......@@ -27,7 +24,6 @@ void scatter_mul_Long (THLongTensor *output, THLongTensor *index, THLongTenso
void scatter_div_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_div_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
/* void scatter_div_Half (THHalfTensor *output, THLongTensor *index, THHalfTensor *input, int dim); */
void scatter_div_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_div_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_div_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
......
......@@ -4,8 +4,7 @@
void scatter_(add)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
......@@ -15,34 +14,31 @@ void scatter_(add)(THTensor *output, THLongTensor *index, THTensor *input, int d
void scatter_(sub)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx] += *(input_data + i * input_stride);
output_data[idx] -= *(input_data + i * input_stride);
})
}
void scatter_(mul)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx] += *(input_data + i * input_stride);
output_data[idx] *= *(input_data + i * input_stride);
})
}
void scatter_(div)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx] += *(input_data + i * input_stride);
output_data[idx] /= *(input_data + i * input_stride);
})
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment