Commit 26f5fa37 authored by rusty1s's avatar rusty1s
Browse files

added cpp boilerplate

parent 1e785d55
#include <Python.h>
#include <torch/script.h>
#ifdef WITH_CUDA
#include "cuda/knn_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__knn(void) { return NULL; }
#endif
torch::Tensor knn(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, int64_t k, bool cosine) {
if (x.device().is_cuda()) {
#ifdef WITH_CUDA
return knn_cuda(x, y, ptr_x, ptr_y, k, cosine);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
AT_ERROR("No CPU version supported");
}
}
static auto registry =
torch::RegisterOperators().op("torch_cluster::knn", &knn);
#include <Python.h>
#include <torch/script.h>
#ifdef WITH_CUDA
#include "cuda/nearest_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__nearest(void) { return NULL; }
#endif
torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y) {
if (x.device().is_cuda()) {
#ifdef WITH_CUDA
return nearest_cuda(x, y, ptr_x, ptr_y);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
AT_ERROR("No CPU version supported");
}
}
static auto registry =
torch::RegisterOperators().op("torch_cluster::nearest", &nearest);
#include <Python.h>
#include <torch/script.h>
#ifdef WITH_CUDA
#include "cuda/radius_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__radius(void) { return NULL; }
#endif
torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, double r, int64_t max_num_neighbors) {
if (x.device().is_cuda()) {
#ifdef WITH_CUDA
return radius_cuda(x, y, ptr_x, ptr_y, r, max_num_neighbors);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
AT_ERROR("No CPU version supported");
}
}
static auto registry =
torch::RegisterOperators().op("torch_cluster::radius", &radius);
...@@ -8,7 +8,8 @@ expected_torch_version = (1, 4) ...@@ -8,7 +8,8 @@ expected_torch_version = (1, 4)
try: try:
for library in [ for library in [
'_version', '_grid', '_graclus', '_fps', '_rw', '_sampler' '_version', '_grid', '_graclus', '_fps', '_rw', '_sampler',
'_nearest', '_knn', '_radius'
]: ]:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec( torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin) library, [osp.dirname(__file__)]).origin)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment