Commit 05665f46 authored by rusty1s's avatar rusty1s
Browse files

long to int64_t

parent bc666bf5
...@@ -4,7 +4,7 @@ from setuptools import setup, find_packages ...@@ -4,7 +4,7 @@ from setuptools import setup, find_packages
import build # noqa import build # noqa
setup( setup(
name='pytorch_scatter', name='torch_scatter',
version='0.1', version='0.1',
description='PyTorch extension for various scatter methods', description='PyTorch extension for various scatter methods',
url='https://github.com/rusty1s/pytorch_scatter', url='https://github.com/rusty1s/pytorch_scatter',
...@@ -14,4 +14,5 @@ setup( ...@@ -14,4 +14,5 @@ setup(
setup_requires=['cffi>=1.0.0'], setup_requires=['cffi>=1.0.0'],
packages=find_packages(exclude=['build']), packages=find_packages(exclude=['build']),
ext_package='', ext_package='',
cffi_modules=[osp.join(osp.dirname(__file__), 'build.py:ffi')], ) cffi_modules=[osp.join(osp.dirname(__file__), 'build.py:ffi')],
)
#include <TH/TH.h> #include <TH/TH.h>
#include "THTensorDimApply.h"
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _, Real) #define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _, Real)
inline void assertIndexInBoundaries(int idx, int size, long *free) { inline void assertIndexInBoundaries(int idx, int size, int64_t *free) {
if (idx < 0 || idx >= size) { THFree(free); THError("Invalid index"); } if (idx < 0 || idx >= size) { THFree(free); THError("Invalid index"); }
} }
#include "generic/cpu.c" #include "generic/cpu.c"
#include "THGenerateAllTypes.h" #include "THGenerateAllTypes.h"
#include "generic/cpu.c"
#include "THGenerateHalfType.h"
void scatter_add_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim); void scatter_add_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_add_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim); void scatter_add_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
/* void scatter_add_Half (THHalfTensor *output, THLongTensor *index, THHalfTensor *input, int dim); */
void scatter_add_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim); void scatter_add_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_add_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim); void scatter_add_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_add_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim); void scatter_add_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
void scatter_add_Int (THIntTensor *output, THLongTensor *index, THIntTensor *input, int dim); void scatter_add_Int (THIntTensor *output, THLongTensor *index, THIntTensor *input, int dim);
void scatter_add_Long (THLongTensor *output, THLongTensor *index, THLongTensor *input, int dim); void scatter_add_Long (THLongTensor *output, THLongTensor *index, THLongTensor *input, int dim);
void scatter_sub_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_sub_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
/* void scatter_sub_Half (THHalfTensor *output, THLongTensor *index, THHalfTensor *input, int dim); */
void scatter_sub_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_sub_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_sub_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
void scatter_sub_Int (THIntTensor *output, THLongTensor *index, THIntTensor *input, int dim);
void scatter_sub_Long (THLongTensor *output, THLongTensor *index, THLongTensor *input, int dim);
void scatter_mul_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_mul_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
/* void scatter_mul_Half (THHalfTensor *output, THLongTensor *index, THHalfTensor *input, int dim); */
void scatter_mul_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_mul_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_mul_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
void scatter_mul_Int (THIntTensor *output, THLongTensor *index, THIntTensor *input, int dim);
void scatter_mul_Long (THLongTensor *output, THLongTensor *index, THLongTensor *input, int dim);
void scatter_div_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_div_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
/* void scatter_div_Half (THHalfTensor *output, THLongTensor *index, THHalfTensor *input, int dim); */
void scatter_div_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_div_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_div_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
void scatter_div_Int (THIntTensor *output, THLongTensor *index, THIntTensor *input, int dim);
void scatter_div_Long (THLongTensor *output, THLongTensor *index, THLongTensor *input, int dim);
...@@ -3,10 +3,43 @@ ...@@ -3,10 +3,43 @@
#else #else
void scatter_(add)(THTensor *output, THLongTensor *index, THTensor *input, int dim) { void scatter_(add)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
long idx; int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim, TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int i = 0; i < THLongTensor_size(index, dim); i++) { for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx] += *(input_data + i * input_stride);
})
}
void scatter_(sub)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx] += *(input_data + i * input_stride);
})
}
void scatter_(mul)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx] += *(input_data + i * input_stride);
})
}
void scatter_(div)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride); idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter); assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx] += *(input_data + i * input_stride); output_data[idx] += *(input_data + i * input_stride);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment