Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
19df6430
Commit
19df6430
authored
Dec 11, 2020
by
Ken Leidal
Browse files
build both cpu and gpu binaries so same package can run on both CPU and GPU machines
parent
981731f0
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
31 deletions
+43
-31
setup.py
setup.py
+37
-31
torch_scatter/__init__.py
torch_scatter/__init__.py
+6
-0
No files found.
setup.py
View file @
19df6430
...
...
@@ -7,7 +7,7 @@ import torch
from
torch.utils.cpp_extension
import
BuildExtension
from
torch.utils.cpp_extension
import
CppExtension
,
CUDAExtension
,
CUDA_HOME
WITH_CUDA
=
torch
.
cuda
.
is_available
()
and
CUDA_HOME
is
not
None
WITH_CUDA
=
CUDA_HOME
is
not
None
if
os
.
getenv
(
'FORCE_CUDA'
,
'0'
)
==
'1'
:
WITH_CUDA
=
True
if
os
.
getenv
(
'FORCE_CPU'
,
'0'
)
==
'1'
:
...
...
@@ -17,11 +17,18 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
def
get_extensions
():
extensions
=
[]
for
with_cuda
,
supername
in
[
(
False
,
"cpu"
),
(
True
,
"gpu"
),
]:
if
with_cuda
and
not
WITH_CUDA
:
continue
Extension
=
CppExtension
define_macros
=
[]
extra_compile_args
=
{
'cxx'
:
[]}
if
WITH_CUDA
:
if
with_cuda
:
Extension
=
CUDAExtension
define_macros
+=
[(
'WITH_CUDA'
,
None
)]
nvcc_flags
=
os
.
getenv
(
'NVCC_FLAGS'
,
''
)
...
...
@@ -31,7 +38,6 @@ def get_extensions():
extensions_dir
=
osp
.
join
(
osp
.
dirname
(
osp
.
abspath
(
__file__
)),
'csrc'
)
main_files
=
glob
.
glob
(
osp
.
join
(
extensions_dir
,
'*.cpp'
))
extensions
=
[]
for
main
in
main_files
:
name
=
main
.
split
(
os
.
sep
)[
-
1
][:
-
4
]
...
...
@@ -42,11 +48,11 @@ def get_extensions():
sources
+=
[
path
]
path
=
osp
.
join
(
extensions_dir
,
'cuda'
,
f
'
{
name
}
_cuda.cu'
)
if
WITH_CUDA
and
osp
.
exists
(
path
):
if
with_cuda
and
osp
.
exists
(
path
):
sources
+=
[
path
]
extension
=
Extension
(
'torch_scatter._
'
+
name
,
'torch_scatter._
%s_%s'
%
(
name
,
supername
)
,
sources
,
include_dirs
=
[
extensions_dir
],
define_macros
=
define_macros
,
...
...
torch_scatter/__init__.py
View file @
19df6430
...
...
@@ -6,8 +6,14 @@ import torch
__version__
=
'2.0.5'
if
torch
.
cuda
.
is_available
():
sublib
=
"gpu"
else
:
sublib
=
"cpu"
try
:
for
library
in
[
'_version'
,
'_scatter'
,
'_segment_csr'
,
'_segment_coo'
]:
library
=
"%s_%s"
%
(
library
,
sublib
)
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
library
,
[
osp
.
dirname
(
__file__
)]).
origin
)
except
AttributeError
as
e
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment