Commit 4e2e69be authored by rusty1s

major clean up

parent 1bbf8bdc
#include "radius_cpu.h"
#include <algorithm>
#include "utils.h"
#include <cstdint>
torch::Tensor knn_cpu(torch::Tensor support, torch::Tensor query,
int64_t k, int64_t n_threads){
CHECK_CPU(query);
CHECK_CPU(support);
torch::Tensor out;
std::vector<size_t>* neighbors_indices = new std::vector<size_t>();
auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
int max_count = 0;
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "radius_cpu", [&] {
auto data_q = query.data_ptr<scalar_t>();
auto data_s = support.data_ptr<scalar_t>();
std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
data_q + query.size(0)*query.size(1));
std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
data_s + support.size(0)*support.size(1));
int dim = torch::size(query, 1);
max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl, neighbors_indices, 0, dim, 0, n_threads, k, 0);
});
size_t* neighbors_indices_ptr = neighbors_indices->data();
const long long tsize = static_cast<long long>(neighbors_indices->size()/2);
out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options);
out = out.t();
auto result = torch::zeros_like(out);
auto index = torch::tensor({1,0});
result.index_copy_(0, index, out);
return result;
}
#include "knn_cpu.h"
void get_size_batch(const std::vector<long>& batch, std::vector<long>& res){
res.resize(batch[batch.size()-1]-batch[0]+1, 0);
long ind = batch[0];
long incr = 1;
for(unsigned long i=1; i < batch.size(); i++){
if(batch[i] == ind)
incr++;
else{
res[ind-batch[0]] = incr;
incr =1;
ind = batch[i];
}
}
res[ind-batch[0]] = incr;
}
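For readers following the removed helper: `get_size_batch` collapses a sorted per-point batch vector into per-batch point counts. A minimal PyTorch sketch of the same computation (example values are made up):

import torch

batch = torch.tensor([0, 0, 0, 1, 1, 2])  # sorted batch id per point
# Count points per id, offset by the first id, matching the
# res[ind - batch[0]] indexing above; missing ids stay zero, as in the C++.
sizes = torch.bincount(batch - batch[0])
print(sizes.tolist())  # [3, 2, 1]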
#include "utils.h"
#include "utils/neighbors.cpp"
torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, int64_t k,
int64_t num_workers) {
CHECK_CPU(x);
CHECK_INPUT(x.dim() == 2);
CHECK_CPU(y);
CHECK_INPUT(y.dim() == 2);
if (ptr_x.has_value()) {
CHECK_CPU(ptr_x.value());
CHECK_INPUT(ptr_x.value().dim() == 1);
}
if (ptr_y.has_value()) {
CHECK_CPU(ptr_y.value());
CHECK_INPUT(ptr_y.value().dim() == 1);
}
std::vector<size_t> *out_vec = new std::vector<size_t>();
AT_DISPATCH_ALL_TYPES(x.scalar_type(), "radius_cpu", [&] {
auto x_data = x.data_ptr<scalar_t>();
auto y_data = y.data_ptr<scalar_t>();
auto x_vec = std::vector<scalar_t>(x_data, x_data + x.numel());
auto y_vec = std::vector<scalar_t>(y_data, y_data + y.numel());
if (!ptr_x.has_value()) {
nanoflann_neighbors<scalar_t>(y_vec, x_vec, out_vec, 0, x.size(-1), 0,
num_workers, k, 0);
} else {
auto sx = (ptr_x.value().narrow(0, 1, ptr_x.value().numel() - 1) -
ptr_x.value().narrow(0, 0, ptr_x.value().numel() - 1));
auto sy = (ptr_y.value().narrow(0, 1, ptr_y.value().numel() - 1) -
ptr_y.value().narrow(0, 0, ptr_y.value().numel() - 1));
auto sx_data = sx.data_ptr<int64_t>();
auto sy_data = sy.data_ptr<int64_t>();
auto sx_vec = std::vector<long>(sx_data, sx_data + sx.numel());
auto sy_vec = std::vector<long>(sy_data, sy_data + sy.numel());
batch_nanoflann_neighbors<scalar_t>(y_vec, x_vec, sy_vec, sx_vec, out_vec,
k, x.size(-1), 0, k, 0);
}
});
const int64_t size = out_vec->size() / 2;
auto out = torch::from_blob(out_vec->data(), {size, 2},
x.options().dtype(torch::kLong));
return out.t().index_select(0, torch::tensor({1, 0}));
}
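In the batched branch above, `narrow` turns the CSR-style `ptr` offsets into per-batch sizes. A small sketch of that conversion (hypothetical offsets):

import torch

ptr = torch.tensor([0, 2, 5])  # two batches holding 2 and 3 points
sizes = ptr.narrow(0, 1, ptr.numel() - 1) - ptr.narrow(0, 0, ptr.numel() - 1)
print(sizes.tolist())  # [2, 3]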
torch::Tensor batch_knn_cpu(torch::Tensor support,
torch::Tensor query,
torch::Tensor support_batch,
torch::Tensor query_batch,
int64_t k) {
CHECK_CPU(query);
CHECK_CPU(support);
CHECK_CPU(query_batch);
CHECK_CPU(support_batch);
torch::Tensor out;
auto data_qb = query_batch.data_ptr<int64_t>();
auto data_sb = support_batch.data_ptr<int64_t>();
std::vector<long> query_batch_stl = std::vector<long>(data_qb, data_qb+query_batch.size(0));
std::vector<long> size_query_batch_stl;
CHECK_INPUT(std::is_sorted(query_batch_stl.begin(),query_batch_stl.end()));
get_size_batch(query_batch_stl, size_query_batch_stl);
std::vector<long> support_batch_stl = std::vector<long>(data_sb, data_sb+support_batch.size(0));
std::vector<long> size_support_batch_stl;
CHECK_INPUT(std::is_sorted(support_batch_stl.begin(),support_batch_stl.end()));
get_size_batch(support_batch_stl, size_support_batch_stl);
std::vector<size_t>* neighbors_indices = new std::vector<size_t>();
auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
int max_count = 0;
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_radius_cpu", [&] {
auto data_q = query.data_ptr<scalar_t>();
auto data_s = support.data_ptr<scalar_t>();
std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
data_q + query.size(0)*query.size(1));
std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
data_s + support.size(0)*support.size(1));
int dim = torch::size(query, 1);
max_count = batch_nanoflann_neighbors<scalar_t>(queries_stl,
supports_stl,
size_query_batch_stl,
size_support_batch_stl,
neighbors_indices,
0,
dim,
0,
k, 0);
});
size_t* neighbors_indices_ptr = neighbors_indices->data();
const long long tsize = static_cast<long long>(neighbors_indices->size()/2);
out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options);
out = out.t();
auto result = torch::zeros_like(out);
auto index = torch::tensor({1,0});
result.index_copy_(0, index, out);
return result;
}
\ No newline at end of file
#pragma once
#include <torch/extension.h>
#include "utils/neighbors.cpp"
#include <iostream>
torch::Tensor knn_cpu(torch::Tensor support, torch::Tensor query,
int64_t k, int64_t n_threads);
torch::Tensor batch_knn_cpu(torch::Tensor support,
torch::Tensor query,
torch::Tensor support_batch,
torch::Tensor query_batch,
int64_t k);
\ No newline at end of file
torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, int64_t k,
int64_t num_workers);
#include "radius_cpu.h"
#include <algorithm>
#include "utils.h"
#include <cstdint>
torch::Tensor radius_cpu(torch::Tensor support, torch::Tensor query,
double radius, int64_t max_num, int64_t n_threads){
CHECK_CPU(query);
CHECK_CPU(support);
torch::Tensor out;
std::vector<size_t>* neighbors_indices = new std::vector<size_t>();
auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
int max_count = 0;
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "radius_cpu", [&] {
auto data_q = query.data_ptr<scalar_t>();
auto data_s = support.data_ptr<scalar_t>();
std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
data_q + query.size(0)*query.size(1));
std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
data_s + support.size(0)*support.size(1));
int dim = torch::size(query, 1);
max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl, neighbors_indices, radius, dim, max_num, n_threads, 0, 1);
});
size_t* neighbors_indices_ptr = neighbors_indices->data();
const long long tsize = static_cast<long long>(neighbors_indices->size()/2);
out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options);
out = out.t();
auto result = torch::zeros_like(out);
auto index = torch::tensor({1,0});
result.index_copy_(0, index, out);
return result;
}
void get_size_batch(const std::vector<long>& batch, std::vector<long>& res){
res.resize(batch[batch.size()-1]-batch[0]+1, 0);
long ind = batch[0];
long incr = 1;
for(unsigned long i=1; i < batch.size(); i++){
if(batch[i] == ind)
incr++;
else{
res[ind-batch[0]] = incr;
incr =1;
ind = batch[i];
}
}
res[ind-batch[0]] = incr;
}
#include "utils.h"
#include "utils/neighbors.cpp"
torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors, int64_t num_workers) {
CHECK_CPU(x);
CHECK_INPUT(x.dim() == 2);
CHECK_CPU(y);
CHECK_INPUT(y.dim() == 2);
if (ptr_x.has_value()) {
CHECK_CPU(ptr_x.value());
CHECK_INPUT(ptr_x.value().dim() == 1);
}
if (ptr_y.has_value()) {
CHECK_CPU(ptr_y.value());
CHECK_INPUT(ptr_y.value().dim() == 1);
}
std::vector<size_t> *out_vec = new std::vector<size_t>();
AT_DISPATCH_ALL_TYPES(x.scalar_type(), "radius_cpu", [&] {
auto x_data = x.data_ptr<scalar_t>();
auto y_data = y.data_ptr<scalar_t>();
auto x_vec = std::vector<scalar_t>(x_data, x_data + x.numel());
auto y_vec = std::vector<scalar_t>(y_data, y_data + y.numel());
if (!ptr_x.has_value()) {
nanoflann_neighbors<scalar_t>(y_vec, x_vec, out_vec, r, x.size(-1),
max_num_neighbors, num_workers, 0, 1);
} else {
auto sx = (ptr_x.value().narrow(0, 1, ptr_x.value().numel() - 1) -
ptr_x.value().narrow(0, 0, ptr_x.value().numel() - 1));
auto sy = (ptr_y.value().narrow(0, 1, ptr_y.value().numel() - 1) -
ptr_y.value().narrow(0, 0, ptr_y.value().numel() - 1));
auto sx_data = sx.data_ptr<int64_t>();
auto sy_data = sy.data_ptr<int64_t>();
auto sx_vec = std::vector<long>(sx_data, sx_data + sx.numel());
auto sy_vec = std::vector<long>(sy_data, sy_data + sy.numel());
batch_nanoflann_neighbors<scalar_t>(y_vec, x_vec, sy_vec, sx_vec, out_vec,
r, x.size(-1), max_num_neighbors, 0,
1);
}
});
const int64_t size = out_vec->size() / 2;
auto out = torch::from_blob(out_vec->data(), {size, 2},
x.options().dtype(torch::kLong));
return out.t().index_select(0, torch::tensor({1, 0}));
}
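Both CPU entry points finish by viewing the flat index buffer as (num_pairs, 2), transposing, and swapping the two rows. A sketch of that post-processing on made-up indices:

import torch

flat = torch.tensor([0, 2, 0, 3, 1, 0, 1, 1])    # interleaved index pairs
out = flat.view(-1, 2).t()                       # shape [2, num_pairs]
out = out.index_select(0, torch.tensor([1, 0]))  # swap the two index rows
print(out.tolist())  # [[2, 3, 0, 1], [0, 0, 1, 1]]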
torch::Tensor batch_radius_cpu(torch::Tensor support,
torch::Tensor query,
torch::Tensor support_batch,
torch::Tensor query_batch,
double radius, int64_t max_num) {
CHECK_CPU(query);
CHECK_CPU(support);
CHECK_CPU(query_batch);
CHECK_CPU(support_batch);
torch::Tensor out;
auto data_qb = query_batch.data_ptr<int64_t>();
auto data_sb = support_batch.data_ptr<int64_t>();
std::vector<long> query_batch_stl = std::vector<long>(data_qb, data_qb+query_batch.size(0));
std::vector<long> size_query_batch_stl;
CHECK_INPUT(std::is_sorted(query_batch_stl.begin(),query_batch_stl.end()));
get_size_batch(query_batch_stl, size_query_batch_stl);
std::vector<long> support_batch_stl = std::vector<long>(data_sb, data_sb+support_batch.size(0));
std::vector<long> size_support_batch_stl;
CHECK_INPUT(std::is_sorted(support_batch_stl.begin(),support_batch_stl.end()));
get_size_batch(support_batch_stl, size_support_batch_stl);
std::vector<size_t>* neighbors_indices = new std::vector<size_t>();
auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
int max_count = 0;
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_radius_cpu", [&] {
auto data_q = query.data_ptr<scalar_t>();
auto data_s = support.data_ptr<scalar_t>();
std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
data_q + query.size(0)*query.size(1));
std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
data_s + support.size(0)*support.size(1));
int dim = torch::size(query, 1);
max_count = batch_nanoflann_neighbors<scalar_t>(queries_stl,
supports_stl,
size_query_batch_stl,
size_support_batch_stl,
neighbors_indices,
radius,
dim,
max_num,
0, 1);
});
size_t* neighbors_indices_ptr = neighbors_indices->data();
const long long tsize = static_cast<long long>(neighbors_indices->size()/2);
out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options);
out = out.t();
auto result = torch::zeros_like(out);
auto index = torch::tensor({1,0});
result.index_copy_(0, index, out);
return result;
}
\ No newline at end of file
#pragma once
#include <torch/extension.h>
#include "utils/neighbors.cpp"
#include <iostream>
torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
double radius, int64_t max_num, int64_t n_threads);
torch::Tensor batch_radius_cpu(torch::Tensor query,
torch::Tensor support,
torch::Tensor query_batch,
torch::Tensor support_batch,
double radius, int64_t max_num);
\ No newline at end of file
torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors, int64_t num_workers);
@@ -75,16 +75,30 @@ __global__ void knn_kernel(const scalar_t *x, const scalar_t *y,
}
}
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, int64_t k, bool cosine) {
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, int64_t k,
bool cosine) {
CHECK_CUDA(x);
CHECK_INPUT(x.dim() == 2);
CHECK_CUDA(y);
CHECK_CUDA(ptr_x);
CHECK_CUDA(ptr_y);
CHECK_INPUT(y.dim() == 2);
cudaSetDevice(x.get_device());
x = x.view({x.size(0), -1}).contiguous();
y = y.view({y.size(0), -1}).contiguous();
if (ptr_x.has_value()) {
CHECK_CUDA(ptr_x.value());
CHECK_INPUT(ptr_x.value().dim() == 1);
} else {
ptr_x = torch::tensor({0, x.size(0)}, x.options().dtype(torch::kLong));
}
if (ptr_y.has_value()) {
CHECK_CUDA(ptr_y.value());
CHECK_INPUT(ptr_y.value().dim() == 1);
} else {
ptr_y = torch::tensor({0, y.size(0)}, y.options().dtype(torch::kLong));
}
CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
auto dist = torch::full(y.size(0) * k, 1e38, y.options());
auto row = torch::empty(y.size(0) * k, ptr_y.options());
@@ -94,7 +108,7 @@ torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
knn_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>(
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),
col.data_ptr<int64_t>(), k, x.size(1), cosine);
});
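When `ptr_x`/`ptr_y` are not given, the CUDA path above falls back to a single-batch `ptr`, and the kernel launches one block per batch. A sketch of those offsets (hypothetical sizes):

import torch

num_x, num_y = 8, 2
ptr_x = torch.tensor([0, num_x])  # one batch spanning all of x
ptr_y = torch.tensor([0, num_y])  # one batch spanning all of y
num_blocks = ptr_x.numel() - 1    # mirrors the ptr_x.size(0) - 1 launch above
print(num_blocks)  # 1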
......
@@ -2,5 +2,7 @@
#include <torch/extension.h>
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, int64_t k, bool cosine);
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, int64_t k,
bool cosine);
@@ -44,26 +44,40 @@ __global__ void radius_kernel(const scalar_t *x, const scalar_t *y,
}
}
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, double r,
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors) {
CHECK_CUDA(x);
CHECK_INPUT(x.dim() == 2);
CHECK_CUDA(y);
CHECK_CUDA(ptr_x);
CHECK_CUDA(ptr_y);
CHECK_INPUT(y.dim() == 2);
cudaSetDevice(x.get_device());
x = x.view({x.size(0), -1}).contiguous();
y = y.view({y.size(0), -1}).contiguous();
if (ptr_x.has_value()) {
CHECK_CUDA(ptr_x.value());
CHECK_INPUT(ptr_x.value().dim() == 1);
} else {
ptr_x = torch::tensor({0, x.size(0)}, x.options().dtype(torch::kLong));
}
if (ptr_y.has_value()) {
CHECK_CUDA(ptr_y.value());
CHECK_INPUT(ptr_y.value().dim() == 1);
} else {
ptr_y = torch::tensor({0, y.size(0)}, y.options().dtype(torch::kLong));
}
CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
auto row = torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.options());
auto col = torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.options());
auto row =
torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.value().options());
auto col =
torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.value().options());
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
radius_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>(
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r, max_num_neighbors,
x.size(1));
});
......
@@ -2,6 +2,7 @@
#include <torch/extension.h>
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, double r,
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors);
#include <Python.h>
#include <torch/script.h>
#include "cpu/knn_cpu.h"
#ifdef WITH_CUDA
#include "cuda/knn_cuda.h"
#endif
#include "cpu/knn_cpu.h"
#ifdef _WIN32
PyMODINIT_FUNC PyInit__knn(void) { return NULL; }
#endif
torch::Tensor knn(torch::Tensor x, torch::Tensor y, torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, int64_t k, bool cosine, int64_t n_threads) {
torch::Tensor knn(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, int64_t k, bool cosine,
int64_t num_workers) {
if (x.device().is_cuda()) {
#ifdef WITH_CUDA
if (!(ptr_x.has_value()) && !(ptr_y.has_value())) {
auto batch_x = torch::tensor({0,torch::size(x,0)}).to(torch::kLong).to(torch::kCUDA);
auto batch_y = torch::tensor({0,torch::size(y,0)}).to(torch::kLong).to(torch::kCUDA);
return knn_cuda(x, y, batch_x, batch_y, k, cosine);
}
else if (!(ptr_x.has_value())) {
auto batch_x = torch::tensor({0,torch::size(x,0)}).to(torch::kLong).to(torch::kCUDA);
auto batch_y = ptr_y.value();
return knn_cuda(x, y, batch_x, batch_y, k, cosine);
}
else if (!(ptr_y.has_value())) {
auto batch_x = ptr_x.value();
auto batch_y = torch::tensor({0,torch::size(y,0)}).to(torch::kLong).to(torch::kCUDA);
return knn_cuda(x, y, batch_x, batch_y, k, cosine);
}
auto batch_x = ptr_x.value();
auto batch_y = ptr_y.value();
return knn_cuda(x, y, batch_x, batch_y, k, cosine);
return knn_cuda(x, y, ptr_x, ptr_y, k, cosine);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
if (cosine) {
if (cosine)
AT_ERROR("`cosine` argument not supported on CPU");
}
if (!(ptr_x.has_value()) && !(ptr_y.has_value())) {
return knn_cpu(x,y,k,n_threads);
}
if (!(ptr_x.has_value())) {
auto batch_x = torch::zeros({torch::size(x,0)}).to(torch::kLong);
auto batch_y = ptr_y.value();
return batch_knn_cpu(x, y, batch_x, batch_y, k);
}
else if (!(ptr_y.has_value())) {
auto batch_x = ptr_x.value();
auto batch_y = torch::zeros({torch::size(y,0)}).to(torch::kLong);
return batch_knn_cpu(x, y, batch_x, batch_y, k);
}
auto batch_x = ptr_x.value();
auto batch_y = ptr_y.value();
return batch_knn_cpu(x, y, batch_x, batch_y, k);
return knn_cpu(x, y, ptr_x, ptr_y, k, num_workers);
}
}
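Assuming the op is registered as `torch_cluster::knn` (the Python wrapper later in this diff calls it that way), the dispatcher can be exercised directly, with absent `ptr` tensors passed as `None`:

import torch
import torch_cluster  # assumed installed; loads the compiled extension

x, y = torch.rand(8, 3), torch.rand(2, 3)
edge_index = torch.ops.torch_cluster.knn(x, y, None, None, 2, False, 1)
print(edge_index.shape)  # torch.Size([2, 4]): k=2 neighbors for 2 queries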
......
#include <Python.h>
#include <torch/script.h>
#include <iostream>
#include "cpu/radius_cpu.h"
#ifdef WITH_CUDA
#include "cuda/radius_cuda.h"
#endif
#include "cpu/radius_cpu.h"
#ifdef _WIN32
PyMODINIT_FUNC PyInit__radius(void) { return NULL; }
#endif
torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r, int64_t max_num_neighbors, int64_t n_threads) {
torch::Tensor radius(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors, int64_t num_workers) {
if (x.device().is_cuda()) {
#ifdef WITH_CUDA
if (!(ptr_x.has_value()) && !(ptr_y.has_value())) {
auto batch_x = torch::tensor({0,torch::size(x,0)}).to(torch::kLong).to(torch::kCUDA);
auto batch_y = torch::tensor({0,torch::size(y,0)}).to(torch::kLong).to(torch::kCUDA);
return radius_cuda(x, y, batch_x, batch_y, r, max_num_neighbors);
}
else if (!(ptr_x.has_value())) {
auto batch_x = torch::tensor({0,torch::size(x,0)}).to(torch::kLong).to(torch::kCUDA);
auto batch_y = ptr_y.value();
return radius_cuda(x, y, batch_x, batch_y, r, max_num_neighbors);
}
else if (!(ptr_y.has_value())) {
auto batch_x = ptr_x.value();
auto batch_y = torch::tensor({0,torch::size(y,0)}).to(torch::kLong).to(torch::kCUDA);
return radius_cuda(x, y, batch_x, batch_y, r, max_num_neighbors);
}
auto batch_x = ptr_x.value();
auto batch_y = ptr_y.value();
return radius_cuda(x, y, batch_x, batch_y, r, max_num_neighbors);
return radius_cuda(x, y, ptr_x, ptr_y, r, max_num_neighbors);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
if (!(ptr_x.has_value()) && !(ptr_y.has_value())) {
return radius_cpu(x,y,r,max_num_neighbors, n_threads);
}
if (!(ptr_x.has_value())) {
auto batch_x = torch::zeros({torch::size(x,0)}).to(torch::kLong);
auto batch_y = ptr_y.value();
return batch_radius_cpu(x, y, batch_x, batch_y, r, max_num_neighbors);
}
else if (!(ptr_y.has_value())) {
auto batch_x = ptr_x.value();
auto batch_y = torch::zeros({torch::size(y,0)}).to(torch::kLong);
return batch_radius_cpu(x, y, batch_x, batch_y, r, max_num_neighbors);
}
auto batch_x = ptr_x.value();
auto batch_y = ptr_y.value();
return batch_radius_cpu(x, y, batch_x, batch_y, r, max_num_neighbors);
return radius_cpu(x, y, ptr_x, ptr_y, r, max_num_neighbors, num_workers);
}
}
......
@@ -57,9 +57,9 @@ def get_extensions():
return extensions
install_requires = ['scipy']
install_requires = []
setup_requires = ['pytest-runner']
tests_require = ['pytest', 'pytest-cov']
tests_require = ['pytest', 'pytest-cov', 'scipy']
setup(
name='torch_cluster',
......
@@ -2,7 +2,9 @@ from itertools import product
import pytest
import torch
import scipy.spatial
from torch_cluster import knn, knn_graph
from .utils import grad_dtypes, devices, tensor
@@ -26,9 +28,11 @@ def test_knn(dtype, device):
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
batch_y = tensor([0, 1], torch.long, device)
row, col = knn(x, y, 2, batch_x, batch_y)
col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
row, col = knn(x, y, 2)
assert row.tolist() == [0, 0, 1, 1]
assert col.tolist() == [2, 3, 0, 1]
row, col = knn(x, y, 2, batch_x, batch_y)
assert row.tolist() == [0, 0, 1, 1]
assert col.tolist() == [2, 3, 4, 5]
@@ -48,55 +52,24 @@ def test_knn_graph(dtype, device):
], dtype, device)
row, col = knn_graph(x, k=2, flow='target_to_source')
col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
row, col = knn_graph(x, k=2, flow='source_to_target')
row = row.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn_graph_large(dtype, device):
x = torch.tensor([[-1.0320, 0.2380, 0.2380],
[-1.3050, -0.0930, 0.6420],
[-0.3190, -0.0410, 1.2150],
[1.1400, -0.5390, -0.3140],
[0.8410, 0.8290, 0.6090],
[-1.4380, -0.2420, -0.3260],
[-2.2980, 0.7160, 0.9320],
[-1.3680, -0.4390, 0.1380],
[-0.6710, 0.6060, 1.1800],
[0.3950, -0.0790, 1.4920]],).to(device)
k = 3
truth = set({(4, 8), (2, 8), (9, 8), (8, 0), (0, 7), (2, 1), (9, 4),
(5, 1), (4, 9), (2, 9), (8, 1), (1, 5), (5, 0), (3, 2),
(8, 2), (7, 1), (6, 0), (3, 9), (0, 5), (7, 5), (4, 2),
(1, 0), (0, 1), (7, 0), (6, 8), (9, 2), (6, 1), (5, 7),
(1, 7), (3, 4)})
row, col = knn_graph(x, k=k, flow='target_to_source',
batch=None, n_threads=24, loop=False)
edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
list(col.cpu().numpy()))])
assert(truth == edges)
row, col = knn_graph(x, k=k, flow='target_to_source',
batch=None, n_threads=12, loop=False)
edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
list(col.cpu().numpy()))])
assert(truth == edges)
row, col = knn_graph(x, k=k, flow='target_to_source',
batch=None, n_threads=1, loop=False)
edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
list(col.cpu().numpy()))])
assert(truth == edges)
x = torch.randn(1000, 3)
row, col = knn_graph(x, k=5, flow='target_to_source', loop=True,
num_workers=6)
pred = set([(i, j) for i, j in zip(row.tolist(), col.tolist())])
tree = scipy.spatial.cKDTree(x.numpy())
_, col = tree.query(x.cpu(), k=5)
truth = set([(i, j) for i, ns in enumerate(col) for j in ns])
assert pred == truth
This diff is collapsed.
@@ -48,7 +48,9 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous()
ptr_x: Optional[torch.Tensor] = None
if batch_x is not None:
assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
@@ -59,6 +61,7 @@
ptr_x = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr_x[1:])
ptr_y: Optional[torch.Tensor] = None
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = int(batch_y.max()) + 1
@@ -68,8 +71,6 @@
ptr_y = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr_y[1:])
else:
ptr_y = torch.tensor([0, y.size(0)], device=y.device)
return torch.ops.torch_cluster.knn(x, y, ptr_x, ptr_y, k, cosine,
num_workers)
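For reference, the public wrapper is used as in the updated test earlier in this diff; a minimal batched call:

import torch
from torch_cluster import knn

x = torch.rand(8, 3)
y = torch.rand(2, 3)
batch_x = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
batch_y = torch.tensor([0, 1])
row, col = knn(x, y, 2, batch_x, batch_y)  # 2 nearest x-points per y-point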
@@ -114,10 +115,16 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
"""
assert flow in ['source_to_target', 'target_to_source']
row, col = knn(x, x, k if loop else k + 1, batch, batch, cosine,
num_workers)
row, col = (col, row) if flow == 'source_to_target' else (row, col)
edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine,
num_workers)
if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0]
else:
row, col = edge_index[0], edge_index[1]
if not loop:
mask = row != col
row, col = row[mask], col[mask]
return torch.stack([row, col], dim=0)
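A usage sketch of the reworked `knn_graph`, following the rewritten test earlier in this diff:

import torch
from torch_cluster import knn_graph

x = torch.randn(100, 3)
edge_index = knn_graph(x, k=5, flow='target_to_source', loop=True,
                       num_workers=2)
row, col = edge_index  # with loop=True, self-edges are kept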
@@ -45,7 +45,9 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous()
ptr_x: Optional[torch.Tensor] = None
if batch_x is not None:
assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
@@ -55,9 +57,8 @@ ptr_x = deg.new_zeros(batch_size + 1)
ptr_x = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr_x[1:])
else:
ptr_x = None
ptr_y: Optional[torch.Tensor] = None
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = int(batch_y.max()) + 1
@@ -67,8 +68,6 @@ ptr_y = deg.new_zeros(batch_size + 1)
ptr_y = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr_y[1:])
else:
ptr_y = None
return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
max_num_neighbors, num_workers)
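A minimal call of the `radius` wrapper under the same conventions as `knn`; the example values and the keyword form of `max_num_neighbors` are illustrative:

import torch
from torch_cluster import radius

x = torch.rand(8, 3)
y = torch.rand(2, 3)
batch_x = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
batch_y = torch.tensor([0, 1])
row, col = radius(x, y, 0.5, batch_x, batch_y, max_num_neighbors=32)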
@@ -113,11 +112,16 @@ def radius_graph(x: torch.Tensor, r: float,
"""
assert flow in ['source_to_target', 'target_to_source']
row, col = radius(x, x, r, batch, batch,
max_num_neighbors if loop else max_num_neighbors + 1,
num_workers)
row, col = (col, row) if flow == 'source_to_target' else (row, col)
edge_index = radius(x, x, r, batch, batch,
max_num_neighbors if loop else max_num_neighbors + 1,
num_workers)
if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0]
else:
row, col = edge_index[0], edge_index[1]
if not loop:
mask = row != col
row, col = row[mask], col[mask]
return torch.stack([row, col], dim=0)
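And a usage sketch of the reworked `radius_graph` (example values are made up):

import torch
from torch_cluster import radius_graph

x = torch.randn(100, 3)
edge_index = radius_graph(x, r=0.25, loop=False)
row, col = edge_index  # self-edges removed by the loop mask above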