Unverified Commit fc1b1394 authored by Gerico Vidanes's avatar Gerico Vidanes Committed by GitHub
Browse files

`WITH_PYTHON` conditionals (#313)

parent 7b0aa738
......@@ -4,6 +4,7 @@ set(CMAKE_CXX_STANDARD 14)
set(TORCHSCATTER_VERSION 2.0.9)
option(WITH_CUDA "Enable CUDA support" OFF)
option(WITH_PYTHON "Link to Python when building" ON)
if(WITH_CUDA)
enable_language(CUDA)
......@@ -12,7 +13,10 @@ if(WITH_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
endif()
find_package(Python3 COMPONENTS Development)
if (WITH_PYTHON)
add_definitions(-DWITH_PYTHON)
find_package(Python3 COMPONENTS Development)
endif()
find_package(Torch REQUIRED)
file(GLOB HEADERS csrc/*.h)
......@@ -22,7 +26,10 @@ if(WITH_CUDA)
endif()
add_library(${PROJECT_NAME} SHARED ${OPERATOR_SOURCES})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} Python3::Python)
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
if (WITH_PYTHON)
target_link_libraries(${PROJECT_NAME} PRIVATE Python3::Python)
endif()
set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchScatter)
target_include_directories(${PROJECT_NAME} INTERFACE
......
#pragma once
#include <torch/extension.h>
#include "../extensions.h"
#define MAX_TENSORINFO_DIMS 25
......
#pragma once
#include <torch/extension.h>
#include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
......
#pragma once
#include <torch/extension.h>
#include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cpu(torch::Tensor src, torch::Tensor index,
......
#pragma once
#include <torch/extension.h>
#include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
......
#pragma once
#include <torch/extension.h>
#include "../extensions.h"
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#pragma once
#include <torch/extension.h>
#include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
......
#pragma once
#include <torch/extension.h>
#include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cuda(torch::Tensor src, torch::Tensor index,
......
#pragma once
#include <torch/extension.h>
#include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
......
#pragma once
#include <torch/extension.h>
#include "../extensions.h"
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
......
#include "macros.h"
#include <torch/torch.h>
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/scatter_cpu.h"
......@@ -10,12 +13,14 @@
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__scatter_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__scatter_cpu(void) { return NULL; }
#endif
#endif
#endif
torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
if (src.dim() == 1)
......
#pragma once
#include <torch/extension.h>
#include "macros.h"
#include "extensions.h"
namespace scatter {
SCATTER_API int64_t cuda_version() noexcept;
......
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/segment_coo_cpu.h"
......@@ -10,12 +13,14 @@
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__segment_coo_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__segment_coo_cpu(void) { return NULL; }
#endif
#endif
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_fw(torch::Tensor src, torch::Tensor index,
......
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/segment_csr_cpu.h"
......@@ -10,12 +13,14 @@
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__segment_csr_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__segment_csr_cpu(void) { return NULL; }
#endif
#endif
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_fw(torch::Tensor src, torch::Tensor indptr,
......
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "scatter.h"
#include "macros.h"
......@@ -8,12 +11,14 @@
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__version_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
#endif
#endif
#endif
namespace scatter {
SCATTER_API int64_t cuda_version() noexcept {
......
......@@ -34,7 +34,7 @@ def get_extensions():
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
for main, suffix in product(main_files, suffices):
define_macros = []
define_macros = [('WITH_PYTHON', None)]
if sys.platform == 'win32':
define_macros += [('torchscatter_EXPORTS', None)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment