Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
b8a3c55c
Commit
b8a3c55c
authored
Oct 14, 2019
by
rusty1s
Browse files
pytorch 1.3 support
parent
08dda1ad
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
26 additions
and
14 deletions
+26
-14
.travis.yml
.travis.yml
+1
-1
cpu/compat.h
cpu/compat.h
+5
-0
cpu/dim_apply.h
cpu/dim_apply.h
+11
-9
setup.py
setup.py
+9
-4
No files found.
.travis.yml
View file @
b8a3c55c
...
@@ -17,7 +17,7 @@ before_install:
...
@@ -17,7 +17,7 @@ before_install:
-
export CXX="g++-4.9"
-
export CXX="g++-4.9"
install
:
install
:
-
pip install numpy
-
pip install numpy
-
pip install -
q torch
-f https://download.pytorch.org/whl/nightly/cpu/torch.html
-
pip install -
-pre torch torchvision
-f https://download.pytorch.org/whl/nightly/cpu/torch
_nightly
.html
-
pip install pycodestyle
-
pip install pycodestyle
-
pip install flake8
-
pip install flake8
-
pip install codecov
-
pip install codecov
...
...
cpu/compat.h
0 → 100644
View file @
b8a3c55c
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
cpu/dim_apply.h
View file @
b8a3c55c
...
@@ -2,23 +2,25 @@
...
@@ -2,23 +2,25 @@
#include <torch/extension.h>
#include <torch/extension.h>
#include "compat.h"
#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \
#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \
[&] { \
[&] { \
TYPE1 *TENSOR1##_data = TENSOR1.
data
<TYPE1>();
\
TYPE1 *TENSOR1##_data = TENSOR1.
DATA_PTR
<TYPE1>(); \
auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
\
\
TYPE2 *TENSOR2##_data = TENSOR2.
data
<TYPE2>();
\
TYPE2 *TENSOR2##_data = TENSOR2.
DATA_PTR
<TYPE2>(); \
auto TENSOR2##_size = TENSOR2.size(DIM); \
auto TENSOR2##_size = TENSOR2.size(DIM); \
auto TENSOR2##_stride = TENSOR2.stride(DIM); \
auto TENSOR2##_stride = TENSOR2.stride(DIM); \
\
\
TYPE3 *TENSOR3##_data = TENSOR3.
data
<TYPE3>();
\
TYPE3 *TENSOR3##_data = TENSOR3.
DATA_PTR
<TYPE3>(); \
auto TENSOR3##_size = TENSOR3.size(DIM); \
auto TENSOR3##_size = TENSOR3.size(DIM); \
auto TENSOR3##_stride = TENSOR3.stride(DIM); \
auto TENSOR3##_stride = TENSOR3.stride(DIM); \
\
\
auto dims = TENSOR1.dim(); \
auto dims = TENSOR1.dim(); \
auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
auto counter = zeros.
data
<int64_t>();
\
auto counter = zeros.
DATA_PTR
<int64_t>(); \
bool has_finished = false; \
bool has_finished = false; \
\
\
while (!has_finished) { \
while (!has_finished) { \
...
@@ -59,25 +61,25 @@
...
@@ -59,25 +61,25 @@
#define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \
#define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \
TENSOR4, DIM, CODE) \
TENSOR4, DIM, CODE) \
[&] { \
[&] { \
TYPE1 *TENSOR1##_data = TENSOR1.
data
<TYPE1>();
\
TYPE1 *TENSOR1##_data = TENSOR1.
DATA_PTR
<TYPE1>(); \
auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
\
\
TYPE2 *TENSOR2##_data = TENSOR2.
data
<TYPE2>();
\
TYPE2 *TENSOR2##_data = TENSOR2.
DATA_PTR
<TYPE2>(); \
auto TENSOR2##_size = TENSOR2.size(DIM); \
auto TENSOR2##_size = TENSOR2.size(DIM); \
auto TENSOR2##_stride = TENSOR2.stride(DIM); \
auto TENSOR2##_stride = TENSOR2.stride(DIM); \
\
\
TYPE3 *TENSOR3##_data = TENSOR3.
data
<TYPE3>();
\
TYPE3 *TENSOR3##_data = TENSOR3.
DATA_PTR
<TYPE3>(); \
auto TENSOR3##_size = TENSOR3.size(DIM); \
auto TENSOR3##_size = TENSOR3.size(DIM); \
auto TENSOR3##_stride = TENSOR3.stride(DIM); \
auto TENSOR3##_stride = TENSOR3.stride(DIM); \
\
\
TYPE4 *TENSOR4##_data = TENSOR4.
data
<TYPE4>();
\
TYPE4 *TENSOR4##_data = TENSOR4.
DATA_PTR
<TYPE4>(); \
auto TENSOR4##_size = TENSOR4.size(DIM); \
auto TENSOR4##_size = TENSOR4.size(DIM); \
auto TENSOR4##_stride = TENSOR4.stride(DIM); \
auto TENSOR4##_stride = TENSOR4.stride(DIM); \
\
\
auto dims = TENSOR1.dim(); \
auto dims = TENSOR1.dim(); \
auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
auto counter = zeros.
data
<int64_t>();
\
auto counter = zeros.
DATA_PTR
<int64_t>(); \
bool has_finished = false; \
bool has_finished = false; \
\
\
while (!has_finished) { \
while (!has_finished) { \
...
...
setup.py
View file @
b8a3c55c
...
@@ -3,14 +3,19 @@ from setuptools import setup, find_packages
...
@@ -3,14 +3,19 @@ from setuptools import setup, find_packages
import
torch
import
torch
from
torch.utils.cpp_extension
import
CppExtension
,
CUDAExtension
,
CUDA_HOME
from
torch.utils.cpp_extension
import
CppExtension
,
CUDAExtension
,
CUDA_HOME
TORCH_MAJOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
0
])
TORCH_MINOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
1
])
extra_compile_args
=
[]
extra_compile_args
=
[]
if
platform
.
system
()
!=
'Windows'
:
if
platform
.
system
()
!=
'Windows'
:
extra_compile_args
+=
[
'-Wno-unused-variable'
]
extra_compile_args
+=
[
'-Wno-unused-variable'
]
if
(
TORCH_MAJOR
>
1
)
or
(
TORCH_MAJOR
==
1
and
TORCH_MINOR
>
2
):
extra_compile_args
+=
[
'-DVERSION_GE_1_3'
]
ext_modules
=
[
ext_modules
=
[
CppExtension
(
CppExtension
(
'torch_scatter.scatter_cpu'
,
[
'cpu/scatter.cpp'
],
'torch_scatter.scatter_cpu'
,
[
'cpu/scatter.cpp'
],
extra_compile_args
=
extra_compile_args
)
extra_compile_args
=
extra_compile_args
)
]
]
cmdclass
=
{
'build_ext'
:
torch
.
utils
.
cpp_extension
.
BuildExtension
}
cmdclass
=
{
'build_ext'
:
torch
.
utils
.
cpp_extension
.
BuildExtension
}
...
@@ -20,7 +25,7 @@ if CUDA_HOME is not None:
...
@@ -20,7 +25,7 @@ if CUDA_HOME is not None:
[
'cuda/scatter.cpp'
,
'cuda/scatter_kernel.cu'
])
[
'cuda/scatter.cpp'
,
'cuda/scatter_kernel.cu'
])
]
]
__version__
=
'1.3.
1
'
__version__
=
'1.3.
2
'
url
=
'https://github.com/rusty1s/pytorch_scatter'
url
=
'https://github.com/rusty1s/pytorch_scatter'
install_requires
=
[]
install_requires
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment