Commit bc666bf5 authored by rusty1s's avatar rusty1s
Browse files

update for pytorch 0.3

parent b4e5e1df
#include <TH/TH.h> #include <TH/TH.h>
#include "THTensorDimApply.h"
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _, Real) #define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _, Real)
inline void assertInBoundaries(int idx, int size, long *free) { inline void assertIndexInBoundaries(int idx, int size, long *free) {
if (idx < 0 || idx >= size) { THFree(free); THError("Invalid index"); } if (idx < 0 || idx >= size) { THFree(free); THError("Invalid index"); }
} }
......
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
void scatter_(add)(THTensor *output, THLongTensor *index, THTensor *input, int dim) { void scatter_(add)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
long idx; long idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, long, index, dim, TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int i = 0; i < THLongTensor_size(index, dim); i++) { for (int i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride); idx = *(index_data + i * index_stride);
assertInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter); assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx] += *(input_data + i * input_stride); output_data[idx] += *(input_data + i * input_stride);
}) })
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment