Unverified Commit dbcafbe6 authored by YoannPitarch's avatar YoannPitarch Committed by GitHub
Browse files

Export symbols (#132)

* Export symbols in the DLL (similar to https://github.com/rusty1s/pytorch_sparse/pull/198)

* Export symbols in the DLL (similar to https://github.com/rusty1s/pytorch_sparse/pull/198)

* Export symbols in the DLL (similar to https://github.com/rusty1s/pytorch_sparse/pull/198

)

* Update setup.py
Co-authored-by: default avatarMatthias Fey <matthias.fey@tu-dortmund.de>
parent 9472aef6
...@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__graclus_cpu(void) { return NULL; } ...@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__graclus_cpu(void) { return NULL; }
#endif #endif
#endif #endif
torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col, CLUSTER_API torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) { torch::optional<torch::Tensor> optional_weight) {
if (rowptr.device().is_cuda()) { if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
......
...@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__grid_cpu(void) { return NULL; } ...@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__grid_cpu(void) { return NULL; }
#endif #endif
#endif #endif
torch::Tensor grid(torch::Tensor pos, torch::Tensor size, CLUSTER_API torch::Tensor grid(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_start, torch::optional<torch::Tensor> optional_start,
torch::optional<torch::Tensor> optional_end) { torch::optional<torch::Tensor> optional_end) {
if (pos.device().is_cuda()) { if (pos.device().is_cuda()) {
......
...@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__knn_cpu(void) { return NULL; } ...@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__knn_cpu(void) { return NULL; }
#endif #endif
#endif #endif
torch::Tensor knn(torch::Tensor x, torch::Tensor y, CLUSTER_API torch::Tensor knn(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x, torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, int64_t k, bool cosine, torch::optional<torch::Tensor> ptr_y, int64_t k, bool cosine,
int64_t num_workers) { int64_t num_workers) {
......
#pragma once
#ifdef _WIN32
#if defined(torchcluster_EXPORTS)
#define CLUSTER_API __declspec(dllexport)
#else
#define CLUSTER_API __declspec(dllimport)
#endif
#else
#define CLUSTER_API
#endif
#if (defined __cpp_inline_variables) || __cplusplus >= 201703L
#define CLUSTER_INLINE_VARIABLE inline
#else
#ifdef _MSC_VER
#define CLUSTER_INLINE_VARIABLE __declspec(selectany)
#else
#define CLUSTER_INLINE_VARIABLE __attribute__((weak))
#endif
#endif
#include <Python.h> #include <Python.h>
#include <torch/script.h> #include <torch/script.h>
#include "extensions.h"
#ifdef WITH_CUDA #ifdef WITH_CUDA
#include "cuda/nearest_cuda.h" #include "cuda/nearest_cuda.h"
#endif #endif
...@@ -13,7 +15,7 @@ PyMODINIT_FUNC PyInit__nearest_cpu(void) { return NULL; } ...@@ -13,7 +15,7 @@ PyMODINIT_FUNC PyInit__nearest_cpu(void) { return NULL; }
#endif #endif
#endif #endif
torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x, CLUSTER_API torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y) { torch::Tensor ptr_y) {
if (x.device().is_cuda()) { if (x.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
......
...@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__radius_cpu(void) { return NULL; } ...@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__radius_cpu(void) { return NULL; }
#endif #endif
#endif #endif
torch::Tensor radius(torch::Tensor x, torch::Tensor y, CLUSTER_API torch::Tensor radius(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x, torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r, torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors, int64_t num_workers) { int64_t max_num_neighbors, int64_t num_workers) {
......
...@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__rw_cpu(void) { return NULL; } ...@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__rw_cpu(void) { return NULL; }
#endif #endif
#endif #endif
std::tuple<torch::Tensor, torch::Tensor> CLUSTER_API std::tuple<torch::Tensor, torch::Tensor>
random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start, random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q) { int64_t walk_length, double p, double q) {
if (rowptr.device().is_cuda()) { if (rowptr.device().is_cuda()) {
......
...@@ -11,7 +11,7 @@ PyMODINIT_FUNC PyInit__sampler_cpu(void) { return NULL; } ...@@ -11,7 +11,7 @@ PyMODINIT_FUNC PyInit__sampler_cpu(void) { return NULL; }
#endif #endif
#endif #endif
torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr, CLUSTER_API torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr,
int64_t count, double factor) { int64_t count, double factor) {
if (rowptr.device().is_cuda()) { if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
......
#include <Python.h> #include <Python.h>
#include <torch/script.h> #include <torch/script.h>
#include "cluster.h"
#include "macros.h"
#ifdef WITH_CUDA #ifdef WITH_CUDA
#include <cuda.h> #include <cuda.h>
...@@ -13,13 +15,17 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; } ...@@ -13,13 +15,17 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
#endif #endif
#endif #endif
int64_t cuda_version() {
namespace cluster {
CLUSTER_API int64_t cuda_version() noexcept {
#ifdef WITH_CUDA #ifdef WITH_CUDA
return CUDA_VERSION; return CUDA_VERSION;
#else #else
return -1; return -1;
#endif #endif
} }
} // namespace sparse
static auto registry = static auto registry =
torch::RegisterOperators().op("torch_cluster::cuda_version", &cuda_version); torch::RegisterOperators().op("torch_cluster::cuda_version", &cluster::cuda_version);
...@@ -34,6 +34,10 @@ def get_extensions(): ...@@ -34,6 +34,10 @@ def get_extensions():
for main, suffix in product(main_files, suffices): for main, suffix in product(main_files, suffices):
define_macros = [] define_macros = []
if sys.platform == 'win32':
define_macros += [('torchcluster_EXPORTS', None)]
extra_compile_args = {'cxx': ['-O2']} extra_compile_args = {'cxx': ['-O2']}
if not os.name == 'nt': # Not on Windows: if not os.name == 'nt': # Not on Windows:
extra_compile_args['cxx'] += ['-Wno-sign-compare'] extra_compile_args['cxx'] += ['-Wno-sign-compare']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment