"vscode:/vscode.git/clone" did not exist on "1626374d318c1e5253bfeb8ec9ef80473a807d65"
Commit b8a3c55c authored by rusty1s's avatar rusty1s
Browse files

pytorch 1.3 support

parent 08dda1ad
...@@ -17,7 +17,7 @@ before_install: ...@@ -17,7 +17,7 @@ before_install:
- export CXX="g++-4.9" - export CXX="g++-4.9"
install: install:
- pip install numpy - pip install numpy
- pip install -q torch -f https://download.pytorch.org/whl/nightly/cpu/torch.html - pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
- pip install pycodestyle - pip install pycodestyle
- pip install flake8 - pip install flake8
- pip install codecov - pip install codecov
......
// Compatibility shim for the PyTorch 1.3 C++ API rename of the tensor
// data accessor: at::Tensor::data<T>() was deprecated in favor of
// at::Tensor::data_ptr<T>().  VERSION_GE_1_3 is defined by the build
// (see the -DVERSION_GE_1_3 compile flag added in setup.py when
// torch.__version__ >= 1.3), so callers write tensor.DATA_PTR<T>() and
// get whichever accessor the installed PyTorch provides.
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
...@@ -2,23 +2,25 @@ ...@@ -2,23 +2,25 @@
#include <torch/extension.h> #include <torch/extension.h>
#include "compat.h"
#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \ #define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \
[&] { \ [&] { \
TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \ TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>(); \
auto TENSOR1##_size = TENSOR1.size(DIM); \ auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \ auto TENSOR1##_stride = TENSOR1.stride(DIM); \
\ \
TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \ TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>(); \
auto TENSOR2##_size = TENSOR2.size(DIM); \ auto TENSOR2##_size = TENSOR2.size(DIM); \
auto TENSOR2##_stride = TENSOR2.stride(DIM); \ auto TENSOR2##_stride = TENSOR2.stride(DIM); \
\ \
TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \ TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>(); \
auto TENSOR3##_size = TENSOR3.size(DIM); \ auto TENSOR3##_size = TENSOR3.size(DIM); \
auto TENSOR3##_stride = TENSOR3.stride(DIM); \ auto TENSOR3##_stride = TENSOR3.stride(DIM); \
\ \
auto dims = TENSOR1.dim(); \ auto dims = TENSOR1.dim(); \
auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \ auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
auto counter = zeros.data<int64_t>(); \ auto counter = zeros.DATA_PTR<int64_t>(); \
bool has_finished = false; \ bool has_finished = false; \
\ \
while (!has_finished) { \ while (!has_finished) { \
...@@ -59,25 +61,25 @@ ...@@ -59,25 +61,25 @@
#define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \ #define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \
TENSOR4, DIM, CODE) \ TENSOR4, DIM, CODE) \
[&] { \ [&] { \
TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \ TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>(); \
auto TENSOR1##_size = TENSOR1.size(DIM); \ auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \ auto TENSOR1##_stride = TENSOR1.stride(DIM); \
\ \
TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \ TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>(); \
auto TENSOR2##_size = TENSOR2.size(DIM); \ auto TENSOR2##_size = TENSOR2.size(DIM); \
auto TENSOR2##_stride = TENSOR2.stride(DIM); \ auto TENSOR2##_stride = TENSOR2.stride(DIM); \
\ \
TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \ TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>(); \
auto TENSOR3##_size = TENSOR3.size(DIM); \ auto TENSOR3##_size = TENSOR3.size(DIM); \
auto TENSOR3##_stride = TENSOR3.stride(DIM); \ auto TENSOR3##_stride = TENSOR3.stride(DIM); \
\ \
TYPE4 *TENSOR4##_data = TENSOR4.data<TYPE4>(); \ TYPE4 *TENSOR4##_data = TENSOR4.DATA_PTR<TYPE4>(); \
auto TENSOR4##_size = TENSOR4.size(DIM); \ auto TENSOR4##_size = TENSOR4.size(DIM); \
auto TENSOR4##_stride = TENSOR4.stride(DIM); \ auto TENSOR4##_stride = TENSOR4.stride(DIM); \
\ \
auto dims = TENSOR1.dim(); \ auto dims = TENSOR1.dim(); \
auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \ auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
auto counter = zeros.data<int64_t>(); \ auto counter = zeros.DATA_PTR<int64_t>(); \
bool has_finished = false; \ bool has_finished = false; \
\ \
while (!has_finished) { \ while (!has_finished) { \
......
...@@ -3,14 +3,19 @@ from setuptools import setup, find_packages ...@@ -3,14 +3,19 @@ from setuptools import setup, find_packages
import torch import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
extra_compile_args = [] extra_compile_args = []
if platform.system() != 'Windows': if platform.system() != 'Windows':
extra_compile_args += ['-Wno-unused-variable'] extra_compile_args += ['-Wno-unused-variable']
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
extra_compile_args += ['-DVERSION_GE_1_3']
ext_modules = [ ext_modules = [
CppExtension( CppExtension('torch_scatter.scatter_cpu', ['cpu/scatter.cpp'],
'torch_scatter.scatter_cpu', ['cpu/scatter.cpp'], extra_compile_args=extra_compile_args)
extra_compile_args=extra_compile_args)
] ]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension} cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
...@@ -20,7 +25,7 @@ if CUDA_HOME is not None: ...@@ -20,7 +25,7 @@ if CUDA_HOME is not None:
['cuda/scatter.cpp', 'cuda/scatter_kernel.cu']) ['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
] ]
__version__ = '1.3.1' __version__ = '1.3.2'
url = 'https://github.com/rusty1s/pytorch_scatter' url = 'https://github.com/rusty1s/pytorch_scatter'
install_requires = [] install_requires = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment