Unverified Commit dbcafbe6 authored by YoannPitarch's avatar YoannPitarch Committed by GitHub
Browse files

Export symbols (#132)

* Export symbols in the DLL (similar to https://github.com/rusty1s/pytorch_sparse/pull/198)

* Export symbols in the DLL (similar to https://github.com/rusty1s/pytorch_sparse/pull/198)

* Export symbols in the DLL (similar to https://github.com/rusty1s/pytorch_sparse/pull/198

)

* Update setup.py
Co-authored-by: default avatarMatthias Fey <matthias.fey@tu-dortmund.de>
parent 9472aef6
...@@ -15,8 +15,8 @@ endif() ...@@ -15,8 +15,8 @@ endif()
find_package(Python3 COMPONENTS Development) find_package(Python3 COMPONENTS Development)
find_package(Torch REQUIRED) find_package(Torch REQUIRED)
file(GLOB HEADERS csrc/cluster.h) file(GLOB HEADERS csrc/*.h)
file(GLOB OPERATOR_SOURCES csrc/cpu/*.h csrc/cpu/*.cpp csrc/*.cpp) file(GLOB OPERATOR_SOURCES csrc/*.* csrc/cpu/*.*)
if(WITH_CUDA) if(WITH_CUDA)
file(GLOB OPERATOR_SOURCES ${OPERATOR_SOURCES} csrc/cuda/*.h csrc/cuda/*.cu) file(GLOB OPERATOR_SOURCES ${OPERATOR_SOURCES} csrc/cuda/*.h csrc/cuda/*.cu)
endif() endif()
......
...@@ -2,30 +2,38 @@ ...@@ -2,30 +2,38 @@
#include <torch/extension.h> #include <torch/extension.h>
int64_t cuda_version(); #include "macros.h"
torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, double ratio, namespace cluster {
CLUSTER_API int64_t cuda_version() noexcept;
namespace detail {
CLUSTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version();
} // namespace detail
} // namespace cluster
CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, double ratio,
bool random_start); bool random_start);
torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col, CLUSTER_API torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight); torch::optional<torch::Tensor> optional_weight);
torch::Tensor grid(torch::Tensor pos, torch::Tensor size, CLUSTER_API torch::Tensor grid(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_start, torch::optional<torch::Tensor> optional_start,
torch::optional<torch::Tensor> optional_end); torch::optional<torch::Tensor> optional_end);
torch::Tensor knn(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x, CLUSTER_API torch::Tensor knn(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, int64_t k, bool cosine); torch::Tensor ptr_y, int64_t k, bool cosine);
torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x, CLUSTER_API torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y); torch::Tensor ptr_y);
torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x, CLUSTER_API torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, double r, int64_t max_num_neighbors); torch::Tensor ptr_y, double r, int64_t max_num_neighbors);
std::tuple<torch::Tensor, torch::Tensor> CLUSTER_API std::tuple<torch::Tensor, torch::Tensor>
random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start, random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q); int64_t walk_length, double p, double q);
torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr, CLUSTER_API torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr,
int64_t count, double factor); int64_t count, double factor);
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio, torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start); bool random_start);
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
torch::Tensor graclus_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor graclus_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight); torch::optional<torch::Tensor> optional_weight);
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size, torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_start, torch::optional<torch::Tensor> optional_start,
torch::optional<torch::Tensor> optional_end); torch::optional<torch::Tensor> optional_end);
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y, torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x, torch::optional<torch::Tensor> ptr_x,
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y, torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x, torch::optional<torch::Tensor> ptr_x,
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
std::tuple<torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor>
random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start, random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr, torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
int64_t count, double factor); int64_t count, double factor);
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor") #define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
torch::Tensor ratio, bool random_start); torch::Tensor ratio, bool random_start);
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight); torch::optional<torch::Tensor> optional_weight);
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size, torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_start, torch::optional<torch::Tensor> optional_start,
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x, torch::optional<torch::Tensor> ptr_x,
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y,
torch::Tensor ptr_x, torch::Tensor ptr_y); torch::Tensor ptr_x, torch::Tensor ptr_y);
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x, torch::optional<torch::Tensor> ptr_x,
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
std::tuple<torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor>
random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start, random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
......
#pragma once #pragma once
#include <torch/extension.h> #include "../extensions.h"
#define CHECK_CUDA(x) \ #define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor") AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
......
#include "macros.h"
#include <torch/extension.h>
...@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__fps_cpu(void) { return NULL; } ...@@ -15,7 +15,7 @@ PyMODINIT_FUNC PyInit__fps_cpu(void) { return NULL; }
#endif #endif
#endif #endif
torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio, CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start) { bool random_start) {
if (src.device().is_cuda()) { if (src.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment