Unverified Commit fc1b1394 authored by Gerico Vidanes's avatar Gerico Vidanes Committed by GitHub
Browse files

`WITH_PYTHON` conditionals (#313)

parent 7b0aa738
...@@ -4,6 +4,7 @@ set(CMAKE_CXX_STANDARD 14) ...@@ -4,6 +4,7 @@ set(CMAKE_CXX_STANDARD 14)
set(TORCHSCATTER_VERSION 2.0.9) set(TORCHSCATTER_VERSION 2.0.9)
option(WITH_CUDA "Enable CUDA support" OFF) option(WITH_CUDA "Enable CUDA support" OFF)
option(WITH_PYTHON "Link to Python when building" ON)
if(WITH_CUDA) if(WITH_CUDA)
enable_language(CUDA) enable_language(CUDA)
...@@ -12,7 +13,10 @@ if(WITH_CUDA) ...@@ -12,7 +13,10 @@ if(WITH_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
endif() endif()
find_package(Python3 COMPONENTS Development) if (WITH_PYTHON)
add_definitions(-DWITH_PYTHON)
find_package(Python3 COMPONENTS Development)
endif()
find_package(Torch REQUIRED) find_package(Torch REQUIRED)
file(GLOB HEADERS csrc/*.h) file(GLOB HEADERS csrc/*.h)
...@@ -22,7 +26,10 @@ if(WITH_CUDA) ...@@ -22,7 +26,10 @@ if(WITH_CUDA)
endif() endif()
add_library(${PROJECT_NAME} SHARED ${OPERATOR_SOURCES}) add_library(${PROJECT_NAME} SHARED ${OPERATOR_SOURCES})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} Python3::Python) target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
if (WITH_PYTHON)
target_link_libraries(${PROJECT_NAME} PRIVATE Python3::Python)
endif()
set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchScatter) set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchScatter)
target_include_directories(${PROJECT_NAME} INTERFACE target_include_directories(${PROJECT_NAME} INTERFACE
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
#define MAX_TENSORINFO_DIMS 25 #define MAX_TENSORINFO_DIMS 25
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>> std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>> std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cpu(torch::Tensor src, torch::Tensor index, segment_coo_cpu(torch::Tensor src, torch::Tensor index,
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>> std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor") #define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>> std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>> std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cuda(torch::Tensor src, torch::Tensor index, segment_coo_cuda(torch::Tensor src, torch::Tensor index,
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>> std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
#define CHECK_CUDA(x) \ #define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor") AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
......
#include "macros.h"
#include <torch/torch.h>
#ifdef WITH_PYTHON
#include <Python.h> #include <Python.h>
#endif
#include <torch/script.h> #include <torch/script.h>
#include "cpu/scatter_cpu.h" #include "cpu/scatter_cpu.h"
...@@ -10,12 +13,14 @@ ...@@ -10,12 +13,14 @@
#endif #endif
#ifdef _WIN32 #ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA #ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__scatter_cuda(void) { return NULL; } PyMODINIT_FUNC PyInit__scatter_cuda(void) { return NULL; }
#else #else
PyMODINIT_FUNC PyInit__scatter_cpu(void) { return NULL; } PyMODINIT_FUNC PyInit__scatter_cpu(void) { return NULL; }
#endif #endif
#endif #endif
#endif
torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) { torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
if (src.dim() == 1) if (src.dim() == 1)
......
#pragma once #pragma once
#include <torch/extension.h> #include "extensions.h"
#include "macros.h"
namespace scatter { namespace scatter {
SCATTER_API int64_t cuda_version() noexcept; SCATTER_API int64_t cuda_version() noexcept;
......
#ifdef WITH_PYTHON
#include <Python.h> #include <Python.h>
#endif
#include <torch/script.h> #include <torch/script.h>
#include "cpu/segment_coo_cpu.h" #include "cpu/segment_coo_cpu.h"
...@@ -10,12 +13,14 @@ ...@@ -10,12 +13,14 @@
#endif #endif
#ifdef _WIN32 #ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA #ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__segment_coo_cuda(void) { return NULL; } PyMODINIT_FUNC PyInit__segment_coo_cuda(void) { return NULL; }
#else #else
PyMODINIT_FUNC PyInit__segment_coo_cpu(void) { return NULL; } PyMODINIT_FUNC PyInit__segment_coo_cpu(void) { return NULL; }
#endif #endif
#endif #endif
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>> std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_fw(torch::Tensor src, torch::Tensor index, segment_coo_fw(torch::Tensor src, torch::Tensor index,
......
#ifdef WITH_PYTHON
#include <Python.h> #include <Python.h>
#endif
#include <torch/script.h> #include <torch/script.h>
#include "cpu/segment_csr_cpu.h" #include "cpu/segment_csr_cpu.h"
...@@ -10,12 +13,14 @@ ...@@ -10,12 +13,14 @@
#endif #endif
#ifdef _WIN32 #ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA #ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__segment_csr_cuda(void) { return NULL; } PyMODINIT_FUNC PyInit__segment_csr_cuda(void) { return NULL; }
#else #else
PyMODINIT_FUNC PyInit__segment_csr_cpu(void) { return NULL; } PyMODINIT_FUNC PyInit__segment_csr_cpu(void) { return NULL; }
#endif #endif
#endif #endif
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>> std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_fw(torch::Tensor src, torch::Tensor indptr, segment_csr_fw(torch::Tensor src, torch::Tensor indptr,
......
#ifdef WITH_PYTHON
#include <Python.h> #include <Python.h>
#endif
#include <torch/script.h> #include <torch/script.h>
#include "scatter.h" #include "scatter.h"
#include "macros.h" #include "macros.h"
...@@ -8,12 +11,14 @@ ...@@ -8,12 +11,14 @@
#endif #endif
#ifdef _WIN32 #ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA #ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__version_cuda(void) { return NULL; } PyMODINIT_FUNC PyInit__version_cuda(void) { return NULL; }
#else #else
PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; } PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
#endif #endif
#endif #endif
#endif
namespace scatter { namespace scatter {
SCATTER_API int64_t cuda_version() noexcept { SCATTER_API int64_t cuda_version() noexcept {
......
...@@ -34,7 +34,7 @@ def get_extensions(): ...@@ -34,7 +34,7 @@ def get_extensions():
main_files = glob.glob(osp.join(extensions_dir, '*.cpp')) main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
for main, suffix in product(main_files, suffices): for main, suffix in product(main_files, suffices):
define_macros = [] define_macros = [('WITH_PYTHON', None)]
if sys.platform == 'win32': if sys.platform == 'win32':
define_macros += [('torchscatter_EXPORTS', None)] define_macros += [('torchscatter_EXPORTS', None)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment