Unverified Commit dbcafbe6 authored by YoannPitarch's avatar YoannPitarch Committed by GitHub
Browse files

Export symbols (#132)

* Export symbols in the DLL (similar to https://github.com/rusty1s/pytorch_sparse/pull/198)

* Export symbols in the DLL (similar to https://github.com/rusty1s/pytorch_sparse/pull/198)

* Export symbols in the DLL (similar to https://github.com/rusty1s/pytorch_sparse/pull/198

)

* Update setup.py
Co-authored-by: default avatarMatthias Fey <matthias.fey@tu-dortmund.de>
parent 9472aef6
......@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__graclus_cpu(void) { return NULL; }
#endif
#endif
torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col,
CLUSTER_API torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
......
......@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__grid_cpu(void) { return NULL; }
#endif
#endif
torch::Tensor grid(torch::Tensor pos, torch::Tensor size,
CLUSTER_API torch::Tensor grid(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_start,
torch::optional<torch::Tensor> optional_end) {
if (pos.device().is_cuda()) {
......
......@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__knn_cpu(void) { return NULL; }
#endif
#endif
torch::Tensor knn(torch::Tensor x, torch::Tensor y,
CLUSTER_API torch::Tensor knn(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, int64_t k, bool cosine,
int64_t num_workers) {
......
#pragma once
#ifdef _WIN32
#if defined(torchcluster_EXPORTS)
#define CLUSTER_API __declspec(dllexport)
#else
#define CLUSTER_API __declspec(dllimport)
#endif
#else
#define CLUSTER_API
#endif
#if (defined __cpp_inline_variables) || __cplusplus >= 201703L
#define CLUSTER_INLINE_VARIABLE inline
#else
#ifdef _MSC_VER
#define CLUSTER_INLINE_VARIABLE __declspec(selectany)
#else
#define CLUSTER_INLINE_VARIABLE __attribute__((weak))
#endif
#endif
#include <Python.h>
#include <torch/script.h>
#include "extensions.h"
#ifdef WITH_CUDA
#include "cuda/nearest_cuda.h"
#endif
......@@ -13,7 +15,7 @@ PyMODINIT_FUNC PyInit__nearest_cpu(void) { return NULL; }
#endif
#endif
torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
CLUSTER_API torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y) {
if (x.device().is_cuda()) {
#ifdef WITH_CUDA
......
......@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__radius_cpu(void) { return NULL; }
#endif
#endif
torch::Tensor radius(torch::Tensor x, torch::Tensor y,
CLUSTER_API torch::Tensor radius(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors, int64_t num_workers) {
......
......@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__rw_cpu(void) { return NULL; }
#endif
#endif
std::tuple<torch::Tensor, torch::Tensor>
CLUSTER_API std::tuple<torch::Tensor, torch::Tensor>
random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q) {
if (rowptr.device().is_cuda()) {
......
......@@ -11,7 +11,7 @@ PyMODINIT_FUNC PyInit__sampler_cpu(void) { return NULL; }
#endif
#endif
torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr,
CLUSTER_API torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr,
int64_t count, double factor) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
......
#include <Python.h>
#include <torch/script.h>
#include "cluster.h"
#include "macros.h"
#ifdef WITH_CUDA
#include <cuda.h>
......@@ -13,13 +15,17 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
#endif
#endif
int64_t cuda_version() {
namespace cluster {
CLUSTER_API int64_t cuda_version() noexcept {
#ifdef WITH_CUDA
return CUDA_VERSION;
#else
return -1;
#endif
}
} // namespace sparse
static auto registry =
torch::RegisterOperators().op("torch_cluster::cuda_version", &cuda_version);
torch::RegisterOperators().op("torch_cluster::cuda_version", &cluster::cuda_version);
......@@ -34,6 +34,10 @@ def get_extensions():
for main, suffix in product(main_files, suffices):
define_macros = []
if sys.platform == 'win32':
define_macros += [('torchcluster_EXPORTS', None)]
extra_compile_args = {'cxx': ['-O2']}
if not os.name == 'nt': # Not on Windows:
extra_compile_args['cxx'] += ['-Wno-sign-compare']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment