Commit aa9a3888 authored by Alexander Liao's avatar Alexander Liao
Browse files

Add kNN CPU implementation and multithreading support with test cases; fix positions of arguments

parent 3d682e5c
#include "radius_cpu.h"
#include <algorithm>
#include "utils.h"
#include <cstdint>
// CPU k-nearest-neighbors: for every row of `query`, find the `k` closest
// rows of `support` with a nanoflann kd-tree, using up to `n_threads`
// worker threads. Returns a [2, n_pairs] kLong index tensor whose two rows
// are swapped relative to the raw (a, b) pair order produced by the helper.
torch::Tensor knn_cpu(torch::Tensor support, torch::Tensor query,
                      int64_t k, int64_t n_threads){
    CHECK_CPU(query);
    CHECK_CPU(support);
    torch::Tensor out;
    // nanoflann_neighbors takes the output as `std::vector<size_t>*&`, so it
    // needs a named, mutable pointer. It is freed before returning below
    // (the original leaked it).
    std::vector<size_t>* neighbors_indices = new std::vector<size_t>();
    auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
    int max_count = 0;
    // Dispatch label fixed: the original said "radius_cpu" (copy-paste).
    AT_DISPATCH_ALL_TYPES(query.scalar_type(), "knn_cpu", [&] {
        auto data_q = query.data_ptr<scalar_t>();
        auto data_s = support.data_ptr<scalar_t>();
        // Flatten both point sets into contiguous STL buffers for nanoflann.
        std::vector<scalar_t> queries_stl(data_q,
                                          data_q + query.size(0)*query.size(1));
        std::vector<scalar_t> supports_stl(data_s,
                                           data_s + support.size(0)*support.size(1));
        int dim = torch::size(query, 1);
        // radius=0, max_num=0, option=0: selects the kNN branch of the
        // helper (as opposed to the radius-search branch).
        max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl,
                                                  neighbors_indices, 0, dim,
                                                  0, n_threads, k, 0);
    });
    // The flat result holds (a, b) index pairs; view it as [tsize, 2].
    // from_blob does NOT copy, so `neighbors_indices` must stay alive until
    // the data has been copied into `result` below.
    size_t* neighbors_indices_ptr = neighbors_indices->data();
    const long long tsize = static_cast<long long>(neighbors_indices->size()/2);
    out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options);
    out = out.t();
    // Swap the two rows (pair order) into a freshly-owned tensor.
    auto result = torch::zeros_like(out);
    auto index = torch::tensor({1, 0});
    result.index_copy_(0, index, out);
    // `result` owns its storage now; release the scratch vector (leak fix).
    delete neighbors_indices;
    return result;
}
void get_size_batch(const std::vector<long>& batch, std::vector<long>& res){
res.resize(batch[batch.size()-1]-batch[0]+1, 0);
long ind = batch[0];
long incr = 1;
for(unsigned long i=1; i < batch.size(); i++){
if(batch[i] == ind)
incr++;
else{
res[ind-batch[0]] = incr;
incr =1;
ind = batch[i];
}
}
res[ind-batch[0]] = incr;
}
// Batched CPU kNN: `support_batch` / `query_batch` assign every point to an
// example id (must be sorted — checked below); the per-example sizes are fed
// to batch_nanoflann_neighbors, which presumably restricts the search to the
// matching example (verify in neighbors.cpp). Returns a [2, n_pairs] kLong
// index tensor with its two rows swapped, matching knn_cpu.
torch::Tensor batch_knn_cpu(torch::Tensor support,
                            torch::Tensor query,
                            torch::Tensor support_batch,
                            torch::Tensor query_batch,
                            int64_t k) {
    CHECK_CPU(query);
    CHECK_CPU(support);
    CHECK_CPU(query_batch);
    CHECK_CPU(support_batch);
    torch::Tensor out;
    auto data_qb = query_batch.data_ptr<int64_t>();
    auto data_sb = support_batch.data_ptr<int64_t>();
    // Convert each (sorted) batch vector into per-example element counts.
    std::vector<long> query_batch_stl(data_qb, data_qb + query_batch.size(0));
    std::vector<long> size_query_batch_stl;
    CHECK_INPUT(std::is_sorted(query_batch_stl.begin(), query_batch_stl.end()));
    get_size_batch(query_batch_stl, size_query_batch_stl);
    std::vector<long> support_batch_stl(data_sb, data_sb + support_batch.size(0));
    std::vector<long> size_support_batch_stl;
    CHECK_INPUT(std::is_sorted(support_batch_stl.begin(), support_batch_stl.end()));
    get_size_batch(support_batch_stl, size_support_batch_stl);
    // Output vector is passed by pointer-reference; freed before returning
    // (the original leaked it).
    std::vector<size_t>* neighbors_indices = new std::vector<size_t>();
    auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
    int max_count = 0;
    // Dispatch label fixed: the original said "batch_radius_cpu" (copy-paste).
    AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_knn_cpu", [&] {
        auto data_q = query.data_ptr<scalar_t>();
        auto data_s = support.data_ptr<scalar_t>();
        // Flatten both point sets into contiguous STL buffers for nanoflann.
        std::vector<scalar_t> queries_stl(data_q,
                                          data_q + query.size(0)*query.size(1));
        std::vector<scalar_t> supports_stl(data_s,
                                           data_s + support.size(0)*support.size(1));
        int dim = torch::size(query, 1);
        // radius=0, max_num=0, option=0: selects the kNN branch of the helper.
        max_count = batch_nanoflann_neighbors<scalar_t>(queries_stl,
                                                        supports_stl,
                                                        size_query_batch_stl,
                                                        size_support_batch_stl,
                                                        neighbors_indices,
                                                        0,
                                                        dim,
                                                        0,
                                                        k, 0);
    });
    // View the flat (a, b) index pairs as [tsize, 2] without copying, then
    // swap the two rows into a tensor that owns its own storage.
    size_t* neighbors_indices_ptr = neighbors_indices->data();
    const long long tsize = static_cast<long long>(neighbors_indices->size()/2);
    out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options);
    out = out.t();
    auto result = torch::zeros_like(out);
    auto index = torch::tensor({1, 0});
    result.index_copy_(0, index, out);
    delete neighbors_indices;  // leak fix: `result` already owns a copy
    return result;
}
\ No newline at end of file
#pragma once
#include <torch/extension.h>
// NOTE(review): this includes a .cpp file so the nanoflann template helpers
// are instantiated in every translation unit that uses them; a dedicated
// header would be cleaner -- confirm against the build setup before changing.
#include "utils/neighbors.cpp"
#include <iostream>
#include "compat.h"
// CPU kNN: for each row of `query`, find its `k` nearest rows in `support`
// using up to `n_threads` threads. Returns a [2, n_pairs] kLong index tensor.
torch::Tensor knn_cpu(torch::Tensor support, torch::Tensor query,
int64_t k, int64_t n_threads);
// Batched variant: `support_batch` / `query_batch` hold sorted per-point
// example ids; see batch_knn_cpu in knn_cpu.cpp for the contract.
torch::Tensor batch_knn_cpu(torch::Tensor support,
torch::Tensor query,
torch::Tensor support_batch,
torch::Tensor query_batch,
int64_t k);
\ No newline at end of file
......@@ -4,7 +4,7 @@
#include <cstdint>
torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
torch::Tensor radius_cpu(torch::Tensor support, torch::Tensor query,
double radius, int64_t max_num, int64_t n_threads){
CHECK_CPU(query);
......@@ -26,7 +26,7 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
int dim = torch::size(query, 1);
max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl ,neighbors_indices, radius, dim, max_num, n_threads);
max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl ,neighbors_indices, radius, dim, max_num, n_threads, 0, 1);
});
......@@ -36,7 +36,13 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options=options);
out = out.t();
return out.clone();
auto result = torch::zeros_like(out);
auto index = torch::tensor({1,0});
result.index_copy_(0, index, out);
return result;
}
......@@ -58,10 +64,10 @@ void get_size_batch(const std::vector<long>& batch, std::vector<long>& res){
res[ind-batch[0]] = incr;
}
torch::Tensor batch_radius_cpu(torch::Tensor query,
torch::Tensor support,
torch::Tensor query_batch,
torch::Tensor batch_radius_cpu(torch::Tensor support,
torch::Tensor query,
torch::Tensor support_batch,
torch::Tensor query_batch,
double radius, int64_t max_num) {
CHECK_CPU(query);
......@@ -103,8 +109,8 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
neighbors_indices,
radius,
dim,
max_num
);
max_num,
0, 1);
});
size_t* neighbors_indices_ptr = neighbors_indices->data();
......@@ -114,5 +120,11 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options=options);
out = out.t();
return out.clone();
auto result = torch::zeros_like(out);
auto index = torch::tensor({1,0});
result.index_copy_(0, index, out);
return result;
}
\ No newline at end of file
#pragma once
#include <torch/extension.h>
//#include "utils/neighbors.h"
#include "utils/neighbors.cpp"
#include <iostream>
#include "compat.h"
......
......@@ -3,6 +3,7 @@
#include <set>
#include <cstdint>
#include <thread>
#include <iostream>
typedef struct thread_struct {
void* kd_tree;
......@@ -15,6 +16,8 @@ typedef struct thread_struct {
size_t end;
double search_radius;
bool small;
bool option;
size_t k;
} thread_args;
template<typename scalar_t>
......@@ -37,7 +40,7 @@ void thread_routine(thread_args* targs) {
double search_radius = (double) targs->search_radius;
size_t start = targs->start;
size_t end = targs->end;
auto k = targs->k;
for (size_t i = start; i < end; i++) {
std::vector<scalar_t> p0 = *(((*pcd_query).pts)[i]);
......@@ -46,11 +49,23 @@ void thread_routine(thread_args* targs) {
std::copy(p0.begin(), p0.end(), query_pt);
(*matches)[i].reserve(*max_count);
std::vector<std::pair<size_t, scalar_t> > ret_matches;
std::vector<size_t>* knn_ret_matches = new std::vector<size_t>(k);
std::vector<scalar_t>* knn_dist_matches = new std::vector<scalar_t>(k);
tree_m->lock();
const size_t nMatches = index->radiusSearch(query_pt, (scalar_t)(search_radius+eps), ret_matches, nanoflann::SearchParams());
size_t nMatches;
if (targs->option){
nMatches = index->radiusSearch(query_pt, (scalar_t)(search_radius+eps), ret_matches, nanoflann::SearchParams());
}
else {
nMatches = index->knnSearch(query_pt, k, &(*knn_ret_matches)[0],&(* knn_dist_matches)[0]);
auto temp = new std::vector<std::pair<size_t, scalar_t> >((*knn_dist_matches).size());
for (size_t j = 0; j < (*knn_ret_matches).size(); j++){
(*temp)[j] = std::make_pair( (*knn_ret_matches)[j],(*knn_dist_matches)[j] );
}
ret_matches = *temp;
}
tree_m->unlock();
(*matches)[i] = ret_matches;
......@@ -67,7 +82,8 @@ void thread_routine(thread_args* targs) {
template<typename scalar_t>
size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>& supports,
std::vector<size_t>*& neighbors_indices, double radius, int dim, int64_t max_num, int64_t n_threads){
std::vector<size_t>*& neighbors_indices, double radius, int dim,
int64_t max_num, int64_t n_threads, int64_t k, int option){
const scalar_t search_radius = static_cast<scalar_t>(radius*radius);
......@@ -120,9 +136,21 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
(*list_matches)[i0].reserve(*max_count);
std::vector<std::pair<size_t, scalar_t> > ret_matches;
std::vector<size_t>* knn_ret_matches = new std::vector<size_t>(k);
std::vector<scalar_t>* knn_dist_matches = new std::vector<scalar_t>(k);
const size_t nMatches = index->radiusSearch(query_pt, (scalar_t)(search_radius+eps), ret_matches, search_params);
size_t nMatches;
if (!!(option)){
nMatches = index->radiusSearch(query_pt, (scalar_t)(search_radius+eps), ret_matches, search_params);
}
else {
nMatches = index->knnSearch(query_pt, (size_t)k, &(*knn_ret_matches)[0],&(* knn_dist_matches)[0]);
auto temp = new std::vector<std::pair<size_t, scalar_t> >((*knn_dist_matches).size());
for (size_t j = 0; j < (*knn_ret_matches).size(); j++){
(*temp)[j] = std::make_pair( (*knn_ret_matches)[j],(*knn_dist_matches)[j] );
}
ret_matches = *temp;
}
(*list_matches)[i0] = ret_matches;
if(*max_count < nMatches) *max_count = nMatches;
i0++;
......@@ -171,6 +199,8 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
else {
targs->small = false;
}
targs->option = !!(option);
targs->k = k;
std::thread* temp = new std::thread(thread_routine<scalar_t>, targs);
tid[t] = temp;
}
......@@ -220,7 +250,7 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
std::vector<long>& q_batches,
std::vector<long>& s_batches,
std::vector<size_t>*& neighbors_indices,
double radius, int dim, int64_t max_num){
double radius, int dim, int64_t max_num, int64_t k, int option){
// indices
......@@ -292,14 +322,22 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
// Initial guess of neighbors size
all_inds_dists[i0].reserve(max_count);
// Find neighbors
size_t nMatches = index->radiusSearch(query_pt, r2+eps, all_inds_dists[i0], search_params);
// Update max count
std::vector<std::pair<size_t, float> > indices_dists;
nanoflann::RadiusResultSet<float,size_t> resultSet(r2, indices_dists);
index->findNeighbors(resultSet, query_pt, search_params);
size_t nMatches;
if (!!option) {
nMatches = index->radiusSearch(query_pt, r2+eps, all_inds_dists[i0], search_params);
// Update max count
}
else {
std::vector<size_t>* knn_ret_matches = new std::vector<size_t>(k);
std::vector<scalar_t>* knn_dist_matches = new std::vector<scalar_t>(k);
nMatches = index->knnSearch(query_pt, (size_t)k, &(*knn_ret_matches)[0],&(*knn_dist_matches)[0]);
auto temp = new std::vector<std::pair<size_t, scalar_t> >((*knn_dist_matches).size());
for (size_t j = 0; j < (*knn_ret_matches).size(); j++){
(*temp)[j] = std::make_pair( (*knn_ret_matches)[j],(*knn_dist_matches)[j] );
}
all_inds_dists[i0] = *temp;
}
if (nMatches > max_count)
max_count = nMatches;
// Increment query idx
......
......@@ -4,21 +4,57 @@
#ifdef WITH_CUDA
#include "cuda/knn_cuda.h"
#endif
#include "cpu/knn_cpu.h"
#ifdef _WIN32
PyMODINIT_FUNC PyInit__knn(void) { return NULL; }
#endif
torch::Tensor knn(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, int64_t k, bool cosine) {
torch::Tensor knn(torch::Tensor x, torch::Tensor y, torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, int64_t k, bool cosine, int64_t n_threads) {
if (x.device().is_cuda()) {
#ifdef WITH_CUDA
return knn_cuda(x, y, ptr_x, ptr_y, k, cosine);
if (!(ptr_x.has_value()) && !(ptr_y.has_value())) {
auto batch_x = torch::tensor({0,torch::size(x,0)}).to(torch::kLong).to(torch::kCUDA);
auto batch_y = torch::tensor({0,torch::size(y,0)}).to(torch::kLong).to(torch::kCUDA);
return knn_cuda(x, y, batch_x, batch_y, k, cosine);
}
else if (!(ptr_x.has_value())) {
auto batch_x = torch::tensor({0,torch::size(x,0)}).to(torch::kLong).to(torch::kCUDA);
auto batch_y = ptr_y.value();
return knn_cuda(x, y, batch_x, batch_y, k, cosine);
}
else if (!(ptr_y.has_value())) {
auto batch_x = ptr_x.value();
auto batch_y = torch::tensor({0,torch::size(y,0)}).to(torch::kLong).to(torch::kCUDA);
return knn_cuda(x, y, batch_x, batch_y, k, cosine);
}
auto batch_x = ptr_x.value();
auto batch_y = ptr_y.value();
return knn_cuda(x, y, batch_x, batch_y, k, cosine);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
AT_ERROR("No CPU version supported");
if (cosine) {
AT_ERROR("`cosine` argument not supported on CPU");
}
if (!(ptr_x.has_value()) && !(ptr_y.has_value())) {
return knn_cpu(x,y,k,n_threads);
}
if (!(ptr_x.has_value())) {
auto batch_x = torch::zeros({torch::size(x,0)}).to(torch::kLong);
auto batch_y = ptr_y.value();
return batch_knn_cpu(x, y, batch_x, batch_y, k);
}
else if (!(ptr_y.has_value())) {
auto batch_x = ptr_x.value();
auto batch_y = torch::zeros({torch::size(y,0)}).to(torch::kLong);
return batch_knn_cpu(x, y, batch_x, batch_y, k);
}
auto batch_x = ptr_x.value();
auto batch_y = ptr_y.value();
return batch_knn_cpu(x, y, batch_x, batch_y, k);
}
}
......
......@@ -3,7 +3,7 @@ from itertools import product
import pytest
import torch
from torch_cluster import knn, knn_graph
import pickle
from .utils import grad_dtypes, devices, tensor
......@@ -57,3 +57,35 @@ def test_knn_graph(dtype, device):
row = row.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn_graph_large(dtype, device):
    """kNN graph on a large pre-computed fixture must match the ground-truth
    edge set, and must be identical for every thread count."""
    # Load the reference fixture: input points `x`, `k`, and expected edges.
    # Context manager fixes the leaked file handle of the original.
    with open("test/knn_test_large.pkl", "rb") as f:
        d = pickle.load(f)
    x = d['x'].to(device)
    k = d['k']
    truth = d['edges']
    # The original repeated this block verbatim for 24, 12 and 1 threads.
    for n_threads in (24, 12, 1):
        row, col = knn_graph(x, k=k, flow='source_to_target',
                             batch=None, n_threads=n_threads)
        edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
                                              list(col.cpu().numpy()))])
        assert (truth == edges)
from typing import Optional
import torch
import scipy.spatial
import numpy as np
def knn(x: torch.Tensor, y: torch.Tensor, k: int,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None,
cosine: bool = False) -> torch.Tensor:
cosine: bool = False, n_threads: int = 1) -> torch.Tensor:
r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
:obj:`x`.
......@@ -44,9 +44,13 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y
def is_sorted(x):
return (np.diff(x.detach().cpu()) >= 0).all()
if x.is_cuda:
if batch_x is not None:
assert x.size(0) == batch_x.numel()
assert is_sorted(batch_x)
batch_size = int(batch_x.max()) + 1
deg = x.new_zeros(batch_size, dtype=torch.long)
......@@ -59,6 +63,7 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
if batch_y is not None:
assert y.size(0) == batch_y.numel()
assert is_sorted(batch_y)
batch_size = int(batch_y.max()) + 1
deg = y.new_zeros(batch_size, dtype=torch.long)
......@@ -69,51 +74,32 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
else:
ptr_y = torch.tensor([0, y.size(0)], device=y.device)
return torch.ops.torch_cluster.knn(x, y, ptr_x, ptr_y, k, cosine)
return torch.ops.torch_cluster.knn(x, y, ptr_x,
ptr_y, k, cosine, n_threads)
else:
if batch_x is None:
batch_x = x.new_zeros(x.size(0), dtype=torch.long)
if batch_y is None:
batch_y = y.new_zeros(y.size(0), dtype=torch.long)
assert x.dim() == 2 and batch_x.dim() == 1
assert y.dim() == 2 and batch_y.dim() == 1
assert x.size(1) == y.size(1)
assert x.dim() == 2
if batch_x is not None:
assert batch_x.dim() == 1
assert is_sorted(batch_x)
assert x.size(0) == batch_x.size(0)
assert y.dim() == 2
if batch_y is not None:
assert batch_y.dim() == 1
assert is_sorted(batch_y)
assert y.size(0) == batch_y.size(0)
assert x.size(1) == y.size(1)
if cosine:
raise NotImplementedError('`cosine` argument not supported on CPU')
# Translate and rescale x and y to [0, 1].
min_xy = min(x.min().item(), y.min().item())
x, y = x - min_xy, y - min_xy
max_xy = max(x.max().item(), y.max().item())
x.div_(max_xy)
y.div_(max_xy)
# Concat batch/features to ensure no cross-links between examples.
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], -1)
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], -1)
tree = scipy.spatial.cKDTree(x.detach().numpy())
dist, col = tree.query(y.detach().cpu(), k=k,
distance_upper_bound=x.size(1))
dist = torch.from_numpy(dist).to(x.dtype)
col = torch.from_numpy(col).to(torch.long)
row = torch.arange(col.size(0), dtype=torch.long)
row = row.view(-1, 1).repeat(1, k)
mask = ~torch.isinf(dist).view(-1)
row, col = row.view(-1)[mask], col.view(-1)[mask]
return torch.stack([row, col], dim=0)
return torch.ops.torch_cluster.knn(x, y, batch_x, batch_y,
k, cosine, n_threads)
def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
loop: bool = False, flow: str = 'source_to_target',
cosine: bool = False) -> torch.Tensor:
cosine: bool = False, n_threads: int = 1) -> torch.Tensor:
r"""Computes graph edges to the nearest :obj:`k` points.
Args:
......@@ -145,7 +131,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
"""
assert flow in ['source_to_target', 'target_to_source']
row, col = knn(x, x, k if loop else k + 1, batch, batch, cosine=cosine)
row, col = knn(x, x, k if loop else k + 1, batch, batch,
cosine=cosine, n_threads=n_threads)
row, col = (col, row) if flow == 'source_to_target' else (row, col)
if not loop:
mask = row != col
......
......@@ -75,7 +75,6 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
result = torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
max_num_neighbors, n_threads)
else:
assert x.dim() == 2
if batch_x is not None:
assert batch_x.dim() == 1
......@@ -136,12 +135,7 @@ def radius_graph(x: torch.Tensor, r: float,
row, col = radius(x, x, r, batch, batch,
max_num_neighbors if loop else max_num_neighbors + 1,
n_threads)
if x.is_cuda:
row, col = (col, row) if flow == 'source_to_target' else (row, col)
else:
row, col = (col, row) if flow == 'target_to_source' else (row, col)
if not loop:
mask = row != col
row, col = row[mask], col[mask]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment