"tests/git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "70e6ab26ba3791c351f948a20b86967eb46f693b"
Commit fe98b763 authored by rusty1s's avatar rusty1s
Browse files

first c try

parent 6e42a446
__pycache__/
_ext/
build/
dist/
*.egg-info/
*.so
from os import path as osp
from torch.utils.ffi import create_extension
abs_path = osp.join(osp.dirname(osp.realpath(__file__)), 'torch_scatter')
headers = ['torch_scatter/include/scatter.h']
sources = ['torch_scatter/src/scatter.c']
includes = [osp.join(abs_path, 'include'), osp.join(abs_path, 'src')]
defines = []
extra_objects = []
with_cuda = False
ffi = create_extension(
name='torch_scatter._ext.scatter',
package=True,
verbose=True,
headers=headers,
sources=sources,
includes=includes,
define_macros=defines,
extra_objects=extra_objects,
with_cuda=with_cuda,
relative_to=__file__)
if __name__ == '__main__':
ffi.build()
from os import path as osp
from setuptools import setup, find_packages
import build # noqa
setup(
name='pytorch_scatter',
version='0.1',
description='PyTorch extension for various scatter methods',
url='https://github.com/rusty1s/pytorch_scatter',
author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de',
install_requires=['cffi>=1.0.0'],
setup_requires=['cffi>=1.0.0'],
packages=find_packages(exclude=['build']),
ext_package='',
cffi_modules=[osp.join(osp.dirname(__file__), 'build.py:ffi')], )
void scatter_add_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_add_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
void scatter_add_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_add_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_add_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
void scatter_add_Int (THIntTensor *output, THLongTensor *index, THIntTensor *input, int dim);
void scatter_add_Long (THLongTensor *output, THLongTensor *index, THLongTensor *input, int dim);
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/scatter.c"
#else
inline void check_(asserts)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
// Assert same dimensionality.
THArgCheck(dim >= 0 && dim < THTensor_(nDimension)(output), 4, "Index dimension is out of bounds");
THArgCheck(THLongTensor_nDimension(index) == THTensor_(nDimension)(input), 2, "Index tensor must have same dimensions as input tensor");
THArgCheck(THTensor_(nDimension)(input) == THTensor_(nDimension)(output), 3, "Input tensor must have same dimensions as output tensor");
// Assert same tensor sizes across index and input.
THLongStorage *indexDims = THLongTensor_newSizeOf(index);
THArgCheck(THTensor_(isSize)(input, indexDims), 2, "Index tensor must have the same size as input tensor.");
THLongStorage_free(indexDims);
// Assert same tensor sizes across input and output apart from specified dimension.
for (int d = 0; d < THTensor_(nDimension)(output); d++) {
if (d != dim) THArgCheck(THTensor_(size)(output, d) == THTensor_(size)(input, d), 3, "Input tensor must have same size as output tensor apart from the specified dimension");
}
}
void scatter_(add)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
check_(asserts)(output, index, input, dim); long idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, long, index, dim,
for (int i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
check_inBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx] += *(input_data + i * input_stride);
})
}
#endif
#include <TH/TH.h>
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _, Real)
#define check_(NAME) TH_CONCAT_4(check_, NAME, _, Real)
inline void check_inBoundaries(int idx, int size, long *free) {
if (idx < 0 || idx >= size) { THFree(free); THError("Invalid index"); }
}
#include "generic/scatter.c"
#include "THGenerateAllTypes.h"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment