Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
fc1b1394
Unverified
Commit
fc1b1394
authored
Jul 23, 2022
by
Gerico Vidanes
Committed by
GitHub
Jul 23, 2022
Browse files
`WITH_PYTHON` conditionals (#313)
parent
7b0aa738
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
42 additions
and
15 deletions
+42
-15
CMakeLists.txt
CMakeLists.txt
+9
-2
csrc/cpu/index_info.h
csrc/cpu/index_info.h
+1
-1
csrc/cpu/scatter_cpu.h
csrc/cpu/scatter_cpu.h
+1
-1
csrc/cpu/segment_coo_cpu.h
csrc/cpu/segment_coo_cpu.h
+1
-1
csrc/cpu/segment_csr_cpu.h
csrc/cpu/segment_csr_cpu.h
+1
-1
csrc/cpu/utils.h
csrc/cpu/utils.h
+1
-1
csrc/cuda/scatter_cuda.h
csrc/cuda/scatter_cuda.h
+1
-1
csrc/cuda/segment_coo_cuda.h
csrc/cuda/segment_coo_cuda.h
+1
-1
csrc/cuda/segment_csr_cuda.h
csrc/cuda/segment_csr_cuda.h
+1
-1
csrc/cuda/utils.cuh
csrc/cuda/utils.cuh
+1
-1
csrc/extensions.h
csrc/extensions.h
+2
-0
csrc/scatter.cpp
csrc/scatter.cpp
+5
-0
csrc/scatter.h
csrc/scatter.h
+1
-3
csrc/segment_coo.cpp
csrc/segment_coo.cpp
+5
-0
csrc/segment_csr.cpp
csrc/segment_csr.cpp
+5
-0
csrc/version.cpp
csrc/version.cpp
+5
-0
setup.py
setup.py
+1
-1
No files found.
CMakeLists.txt
View file @
fc1b1394
...
...
@@ -4,6 +4,7 @@ set(CMAKE_CXX_STANDARD 14)
set
(
TORCHSCATTER_VERSION 2.0.9
)
option
(
WITH_CUDA
"Enable CUDA support"
OFF
)
option
(
WITH_PYTHON
"Link to Python when building"
ON
)
if
(
WITH_CUDA
)
enable_language
(
CUDA
)
...
...
@@ -12,7 +13,10 @@ if(WITH_CUDA)
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
--expt-relaxed-constexpr"
)
endif
()
find_package
(
Python3 COMPONENTS Development
)
if
(
WITH_PYTHON
)
add_definitions
(
-DWITH_PYTHON
)
find_package
(
Python3 COMPONENTS Development
)
endif
()
find_package
(
Torch REQUIRED
)
file
(
GLOB HEADERS csrc/*.h
)
...
...
@@ -22,7 +26,10 @@ if(WITH_CUDA)
endif
()
add_library
(
${
PROJECT_NAME
}
SHARED
${
OPERATOR_SOURCES
}
)
target_link_libraries
(
${
PROJECT_NAME
}
PRIVATE
${
TORCH_LIBRARIES
}
Python3::Python
)
target_link_libraries
(
${
PROJECT_NAME
}
PRIVATE
${
TORCH_LIBRARIES
}
)
if
(
WITH_PYTHON
)
target_link_libraries
(
${
PROJECT_NAME
}
PRIVATE Python3::Python
)
endif
()
set_target_properties
(
${
PROJECT_NAME
}
PROPERTIES EXPORT_NAME TorchScatter
)
target_include_directories
(
${
PROJECT_NAME
}
INTERFACE
...
...
csrc/cpu/index_info.h
View file @
fc1b1394
#pragma once
#include
<torch
/extension.h
>
#include
"..
/extension
s
.h
"
#define MAX_TENSORINFO_DIMS 25
...
...
csrc/cpu/scatter_cpu.h
View file @
fc1b1394
#pragma once
#include
<torch
/extension.h
>
#include
"..
/extension
s
.h
"
std
::
tuple
<
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>>
scatter_cpu
(
torch
::
Tensor
src
,
torch
::
Tensor
index
,
int64_t
dim
,
...
...
csrc/cpu/segment_coo_cpu.h
View file @
fc1b1394
#pragma once
#include
<torch
/extension.h
>
#include
"..
/extension
s
.h
"
std
::
tuple
<
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>>
segment_coo_cpu
(
torch
::
Tensor
src
,
torch
::
Tensor
index
,
...
...
csrc/cpu/segment_csr_cpu.h
View file @
fc1b1394
#pragma once
#include
<torch
/extension.h
>
#include
"..
/extension
s
.h
"
std
::
tuple
<
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>>
segment_csr_cpu
(
torch
::
Tensor
src
,
torch
::
Tensor
indptr
,
...
...
csrc/cpu/utils.h
View file @
fc1b1394
#pragma once
#include
<torch
/extension.h
>
#include
"..
/extension
s
.h
"
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
csrc/cuda/scatter_cuda.h
View file @
fc1b1394
#pragma once
#include
<torch
/extension.h
>
#include
"..
/extension
s
.h
"
std
::
tuple
<
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>>
scatter_cuda
(
torch
::
Tensor
src
,
torch
::
Tensor
index
,
int64_t
dim
,
...
...
csrc/cuda/segment_coo_cuda.h
View file @
fc1b1394
#pragma once
#include
<torch
/extension.h
>
#include
"..
/extension
s
.h
"
std
::
tuple
<
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>>
segment_coo_cuda
(
torch
::
Tensor
src
,
torch
::
Tensor
index
,
...
...
csrc/cuda/segment_csr_cuda.h
View file @
fc1b1394
#pragma once
#include
<torch
/extension.h
>
#include
"..
/extension
s
.h
"
std
::
tuple
<
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>>
segment_csr_cuda
(
torch
::
Tensor
src
,
torch
::
Tensor
indptr
,
...
...
csrc/cuda/utils.cuh
View file @
fc1b1394
#pragma once
#include
<torch
/extension.h
>
#include
"..
/extension
s
.h
"
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
...
...
csrc/extensions.h
0 → 100644
View file @
fc1b1394
#include "macros.h"
#include <torch/torch.h>
csrc/scatter.cpp
View file @
fc1b1394
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/scatter_cpu.h"
...
...
@@ -10,12 +13,14 @@
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC
PyInit__scatter_cuda
(
void
)
{
return
NULL
;
}
#else
PyMODINIT_FUNC
PyInit__scatter_cpu
(
void
)
{
return
NULL
;
}
#endif
#endif
#endif
torch
::
Tensor
broadcast
(
torch
::
Tensor
src
,
torch
::
Tensor
other
,
int64_t
dim
)
{
if
(
src
.
dim
()
==
1
)
...
...
csrc/scatter.h
View file @
fc1b1394
#pragma once
#include <torch/extension.h>
#include "macros.h"
#include "extensions.h"
namespace
scatter
{
SCATTER_API
int64_t
cuda_version
()
noexcept
;
...
...
csrc/segment_coo.cpp
View file @
fc1b1394
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/segment_coo_cpu.h"
...
...
@@ -10,12 +13,14 @@
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC
PyInit__segment_coo_cuda
(
void
)
{
return
NULL
;
}
#else
PyMODINIT_FUNC
PyInit__segment_coo_cpu
(
void
)
{
return
NULL
;
}
#endif
#endif
#endif
std
::
tuple
<
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>>
segment_coo_fw
(
torch
::
Tensor
src
,
torch
::
Tensor
index
,
...
...
csrc/segment_csr.cpp
View file @
fc1b1394
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/segment_csr_cpu.h"
...
...
@@ -10,12 +13,14 @@
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC
PyInit__segment_csr_cuda
(
void
)
{
return
NULL
;
}
#else
PyMODINIT_FUNC
PyInit__segment_csr_cpu
(
void
)
{
return
NULL
;
}
#endif
#endif
#endif
std
::
tuple
<
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>>
segment_csr_fw
(
torch
::
Tensor
src
,
torch
::
Tensor
indptr
,
...
...
csrc/version.cpp
View file @
fc1b1394
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "scatter.h"
#include "macros.h"
...
...
@@ -8,12 +11,14 @@
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC
PyInit__version_cuda
(
void
)
{
return
NULL
;
}
#else
PyMODINIT_FUNC
PyInit__version_cpu
(
void
)
{
return
NULL
;
}
#endif
#endif
#endif
namespace
scatter
{
SCATTER_API
int64_t
cuda_version
()
noexcept
{
...
...
setup.py
View file @
fc1b1394
...
...
@@ -34,7 +34,7 @@ def get_extensions():
main_files
=
glob
.
glob
(
osp
.
join
(
extensions_dir
,
'*.cpp'
))
for
main
,
suffix
in
product
(
main_files
,
suffices
):
define_macros
=
[]
define_macros
=
[
(
'WITH_PYTHON'
,
None
)
]
if
sys
.
platform
==
'win32'
:
define_macros
+=
[(
'torchscatter_EXPORTS'
,
None
)]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment