#include "radius_cpu.h"

#include <cstddef>
#include <cstdint>
#include <vector>

#include <torch/extension.h>

#include "utils.h"

// Validates that `query` and `support` are 2-D point tensors with the same
// number of columns and the same dtype. The flattening below and the single
// `dim` value passed to the search both rely on this.
static void check_point_tensors(const torch::Tensor& query,
                                const torch::Tensor& support) {
  TORCH_CHECK(query.dim() == 2, "query must be a 2-D (N, dim) tensor");
  TORCH_CHECK(support.dim() == 2, "support must be a 2-D (M, dim) tensor");
  TORCH_CHECK(query.size(1) == support.size(1),
              "query and support must have the same number of columns");
  TORCH_CHECK(query.scalar_type() == support.scalar_type(),
              "query and support must have the same dtype");
}

// Copies every element of `t`, in row-major order, into a std::vector.
// contiguous() is required: reading numel() elements straight from
// data_ptr() is only valid for a contiguous tensor (not e.g. a transposed view).
template <typename scalar_t>
static std::vector<scalar_t> to_flat_vector(const torch::Tensor& t) {
  const torch::Tensor dense = t.contiguous();
  const scalar_t* first = dense.data_ptr<scalar_t>();
  return std::vector<scalar_t>(first, first + dense.numel());
}

// Copies an int64 batch-id tensor into a std::vector<long>.
static std::vector<long> batch_to_vector(const torch::Tensor& batch) {
  const torch::Tensor dense = batch.contiguous();
  const int64_t* first = dense.data_ptr<int64_t>();
  return std::vector<long>(first, first + dense.numel());
}

// Converts a flat [a0, b0, a1, b1, ...] index list into an int64 tensor of
// shape (2, P): row 0 holds the even-position entries, row 1 the odd ones.
// Elements are converted one by one, so the result owns its storage and does
// not depend on size_t and int64_t having the same width. A trailing unpaired
// element (odd size) is ignored.
static torch::Tensor pairs_to_tensor(const std::vector<size_t>& flat_pairs) {
  const int64_t num_pairs = static_cast<int64_t>(flat_pairs.size() / 2);
  const auto options =
      torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
  torch::Tensor out = torch::empty({2, num_pairs}, options);
  auto acc = out.accessor<int64_t, 2>();
  for (int64_t i = 0; i < num_pairs; ++i) {
    const size_t base = static_cast<size_t>(2 * i);
    acc[0][i] = static_cast<int64_t>(flat_pairs[base]);
    acc[1][i] = static_cast<int64_t>(flat_pairs[base + 1]);
  }
  return out;
}

/// @brief Radius neighbour search between two (unbatched) point sets on the CPU.
/// @param query     CPU tensor of shape (N, dim).
/// @param support   CPU tensor of shape (M, dim), same dtype as `query`.
/// @param radius    Search radius, forwarded to nanoflann_neighbors.
/// @param max_num   Forwarded to nanoflann_neighbors (presumably a cap on
///                  neighbours per query point — TODO confirm).
/// @param n_threads Forwarded to nanoflann_neighbors.
/// @return Contiguous int64 tensor of shape (2, P) built from the flat index
///         pairs filled in by nanoflann_neighbors (presumably row 0 = query
///         index, row 1 = support index — confirm against neighbors code).
/// @throws c10::Error if the inputs are not CPU tensors or fail the shape/dtype checks.
torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
                         double radius, int64_t max_num, int64_t n_threads) {
  CHECK_CPU(query);
  CHECK_CPU(support);
  check_point_tensors(query, support);

  // Stack-owned so it is released on every exit path (including when the
  // search throws); the search fills it through the pointer.
  std::vector<size_t> neighbors_storage;
  std::vector<size_t>* neighbors_indices = &neighbors_storage;

  AT_DISPATCH_ALL_TYPES(query.scalar_type(), "radius_cpu", [&] {
    std::vector<scalar_t> queries_stl = to_flat_vector<scalar_t>(query);
    std::vector<scalar_t> supports_stl = to_flat_vector<scalar_t>(support);
    const int dim = static_cast<int>(query.size(1));
    nanoflann_neighbors(queries_stl, supports_stl, neighbors_indices, radius,
                        dim, max_num, n_threads);
  });

  return pairs_to_tensor(*neighbors_indices);
}

/// @brief Counts how many entries each batch id has in a sorted batch vector.
/// @param batch Batch ids in non-decreasing order (one per point).
/// @param res   Output; replaced so that res[k] is the number of entries equal
///              to batch.front() + k. Ids missing between front() and back()
///              get a count of 0. Left empty when `batch` is empty.
/// @throws c10::Error if `batch` is not sorted in non-decreasing order.
void get_size_batch(const std::vector<long>& batch, std::vector<long>& res) {
  res.clear();
  if (batch.empty()) {
    return;
  }
  // Unsorted ids would previously index `res` out of bounds.
  for (size_t i = 1; i < batch.size(); ++i) {
    TORCH_CHECK(batch[i - 1] <= batch[i],
                "batch indices must be sorted in ascending order");
  }
  const long first = batch.front();
  res.assign(static_cast<size_t>(batch.back() - first + 1), 0);
  for (const long id : batch) {
    ++res[static_cast<size_t>(id - first)];
  }
}

/// @brief Batched radius neighbour search on the CPU.
/// @param query         CPU tensor of shape (N, dim).
/// @param support       CPU tensor of shape (M, dim), same dtype as `query`.
/// @param query_batch   CPU int64 tensor with N sorted batch ids for `query`.
/// @param support_batch CPU int64 tensor with M sorted batch ids for `support`.
/// @param radius        Search radius, forwarded to batch_nanoflann_neighbors.
/// @param max_num       Forwarded to batch_nanoflann_neighbors.
/// @return Contiguous int64 tensor of shape (2, P) built from the flat index
///         pairs filled in by batch_nanoflann_neighbors.
/// @throws c10::Error on non-CPU inputs, shape/dtype mismatch, or unsorted batch ids.
torch::Tensor batch_radius_cpu(torch::Tensor query, torch::Tensor support,
                               torch::Tensor query_batch,
                               torch::Tensor support_batch, double radius,
                               int64_t max_num) {
  CHECK_CPU(query);
  CHECK_CPU(support);
  CHECK_CPU(query_batch);
  CHECK_CPU(support_batch);
  check_point_tensors(query, support);
  TORCH_CHECK(query_batch.numel() == query.size(0),
              "query_batch must contain one batch id per query point");
  TORCH_CHECK(support_batch.numel() == support.size(0),
              "support_batch must contain one batch id per support point");

  std::vector<long> size_query_batch_stl;
  get_size_batch(batch_to_vector(query_batch), size_query_batch_stl);
  std::vector<long> size_support_batch_stl;
  get_size_batch(batch_to_vector(support_batch), size_support_batch_stl);

  // Stack-owned so it is released on every exit path; filled via the pointer.
  std::vector<size_t> neighbors_storage;
  std::vector<size_t>* neighbors_indices = &neighbors_storage;

  AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_radius_cpu", [&] {
    std::vector<scalar_t> queries_stl = to_flat_vector<scalar_t>(query);
    std::vector<scalar_t> supports_stl = to_flat_vector<scalar_t>(support);
    const int dim = static_cast<int>(query.size(1));
    batch_nanoflann_neighbors(queries_stl, supports_stl, size_query_batch_stl,
                              size_support_batch_stl, neighbors_indices,
                              radius, dim, max_num);
  });

  return pairs_to_tensor(*neighbors_indices);
}