"src/vscode:/vscode.git/clone" did not exist on "14af8402f0a2b0218544e0eea95824f8d6e72530"
Commit 05665f46 authored by rusty1s's avatar rusty1s
Browse files

long to int64_t

parent bc666bf5
......@@ -4,7 +4,7 @@ from setuptools import setup, find_packages
import build # noqa
setup(
name='pytorch_scatter',
name='torch_scatter',
version='0.1',
description='PyTorch extension for various scatter methods',
url='https://github.com/rusty1s/pytorch_scatter',
......@@ -14,4 +14,5 @@ setup(
setup_requires=['cffi>=1.0.0'],
packages=find_packages(exclude=['build']),
ext_package='',
cffi_modules=[osp.join(osp.dirname(__file__), 'build.py:ffi')], )
cffi_modules=[osp.join(osp.dirname(__file__), 'build.py:ffi')],
)
#include <TH/TH.h>
#include "THTensorDimApply.h"
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _, Real)
inline void assertIndexInBoundaries(int idx, int size, long *free) {
inline void assertIndexInBoundaries(int idx, int size, int64_t *free) {
if (idx < 0 || idx >= size) { THFree(free); THError("Invalid index"); }
}
#include "generic/cpu.c"
#include "THGenerateAllTypes.h"
#include "generic/cpu.c"
#include "THGenerateHalfType.h"
void scatter_add_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_add_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
/* void scatter_add_Half (THHalfTensor *output, THLongTensor *index, THHalfTensor *input, int dim); */
void scatter_add_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_add_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_add_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
void scatter_add_Int (THIntTensor *output, THLongTensor *index, THIntTensor *input, int dim);
void scatter_add_Long (THLongTensor *output, THLongTensor *index, THLongTensor *input, int dim);
void scatter_sub_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_sub_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
/* void scatter_sub_Half (THHalfTensor *output, THLongTensor *index, THHalfTensor *input, int dim); */
void scatter_sub_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_sub_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_sub_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
void scatter_sub_Int (THIntTensor *output, THLongTensor *index, THIntTensor *input, int dim);
void scatter_sub_Long (THLongTensor *output, THLongTensor *index, THLongTensor *input, int dim);
void scatter_mul_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_mul_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
/* void scatter_mul_Half (THHalfTensor *output, THLongTensor *index, THHalfTensor *input, int dim); */
void scatter_mul_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_mul_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_mul_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
void scatter_mul_Int (THIntTensor *output, THLongTensor *index, THIntTensor *input, int dim);
void scatter_mul_Long (THLongTensor *output, THLongTensor *index, THLongTensor *input, int dim);
void scatter_div_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_div_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
/* void scatter_div_Half (THHalfTensor *output, THLongTensor *index, THHalfTensor *input, int dim); */
void scatter_div_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_div_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_div_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
void scatter_div_Int (THIntTensor *output, THLongTensor *index, THIntTensor *input, int dim);
void scatter_div_Long (THLongTensor *output, THLongTensor *index, THLongTensor *input, int dim);
......@@ -3,10 +3,43 @@
#else
void scatter_(add)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
long idx;
int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int i = 0; i < THLongTensor_size(index, dim); i++) {
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx] += *(input_data + i * input_stride);
})
}
void scatter_(sub)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx] += *(input_data + i * input_stride);
})
}
void scatter_(mul)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx] += *(input_data + i * input_stride);
})
}
void scatter_(div)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx] += *(input_data + i * input_stride);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment