Commit aa9a3888 authored by Alexander Liao

Add CPU kNN with multithreading support and test cases; make the support/query argument order consistent

parent 3d682e5c
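For orientation before the diff: a minimal usage sketch of the CPU k-NN path this commit adds, based on the Python `knn`/`knn_graph` signatures changed further down (the point counts and `k` are illustrative):

import torch
from torch_cluster import knn, knn_graph

x = torch.rand(1000, 3)   # support points
y = torch.rand(100, 3)    # query points

# Plain CPU k-NN; n_threads is the new knob for the multithreaded nanoflann backend.
edge_index = knn(x, y, k=8, n_threads=4)    # shape [2, 100 * 8]

# k-NN graph over a single point cloud, with the same n_threads pass-through.
row, col = knn_graph(x, k=8, n_threads=4)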
#include "knn_cpu.h"
#include "utils.h"
#include <algorithm>
#include <cstdint>

torch::Tensor knn_cpu(torch::Tensor support, torch::Tensor query,
                      int64_t k, int64_t n_threads){
    CHECK_CPU(query);
    CHECK_CPU(support);

    torch::Tensor out;
    std::vector<size_t>* neighbors_indices = new std::vector<size_t>();
    auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
    int max_count = 0;

    AT_DISPATCH_ALL_TYPES(query.scalar_type(), "knn_cpu", [&] {
        auto data_q = query.data_ptr<scalar_t>();
        auto data_s = support.data_ptr<scalar_t>();
        std::vector<scalar_t> queries_stl = std::vector<scalar_t>(
            data_q, data_q + query.size(0)*query.size(1));
        std::vector<scalar_t> supports_stl = std::vector<scalar_t>(
            data_s, data_s + support.size(0)*support.size(1));

        int dim = torch::size(query, 1);
        // radius = 0, max_num = 0 and option = 0 select the k-NN branch of nanoflann_neighbors.
        max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl, neighbors_indices,
                                                  0, dim, 0, n_threads, k, 0);
    });

    // neighbors_indices holds flat (index, index) pairs; view them as an N x 2 tensor.
    size_t* neighbors_indices_ptr = neighbors_indices->data();
    const long long tsize = static_cast<long long>(neighbors_indices->size()/2);
    out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options);
    out = out.t();
    // Swap the two index rows: result[0] = out[1], result[1] = out[0].
    auto result = torch::zeros_like(out);
    auto index = torch::tensor({1,0});
    result.index_copy_(0, index, out);
    // result owns its own storage, so the scratch vector can be released.
    delete neighbors_indices;
    return result;
}
// Convert a sorted batch vector into per-example point counts, e.g.
// batch = [0, 0, 0, 1, 1, 2]  ->  res = [3, 2, 1].
void get_size_batch(const std::vector<long>& batch, std::vector<long>& res){
    res.resize(batch[batch.size()-1] - batch[0] + 1, 0);
    long ind = batch[0];
    long incr = 1;
    for(unsigned long i = 1; i < batch.size(); i++){
        if(batch[i] == ind)
            incr++;
        else{
            res[ind-batch[0]] = incr;
            incr = 1;
            ind = batch[i];
        }
    }
    res[ind-batch[0]] = incr;
}
torch::Tensor batch_knn_cpu(torch::Tensor support,
                            torch::Tensor query,
                            torch::Tensor support_batch,
                            torch::Tensor query_batch,
                            int64_t k) {
    CHECK_CPU(query);
    CHECK_CPU(support);
    CHECK_CPU(query_batch);
    CHECK_CPU(support_batch);

    torch::Tensor out;
    auto data_qb = query_batch.data_ptr<int64_t>();
    auto data_sb = support_batch.data_ptr<int64_t>();

    // Batch vectors must be sorted so that each example occupies a contiguous block.
    std::vector<long> query_batch_stl = std::vector<long>(data_qb, data_qb + query_batch.size(0));
    std::vector<long> size_query_batch_stl;
    CHECK_INPUT(std::is_sorted(query_batch_stl.begin(), query_batch_stl.end()));
    get_size_batch(query_batch_stl, size_query_batch_stl);

    std::vector<long> support_batch_stl = std::vector<long>(data_sb, data_sb + support_batch.size(0));
    std::vector<long> size_support_batch_stl;
    CHECK_INPUT(std::is_sorted(support_batch_stl.begin(), support_batch_stl.end()));
    get_size_batch(support_batch_stl, size_support_batch_stl);

    std::vector<size_t>* neighbors_indices = new std::vector<size_t>();
    auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
    int max_count = 0;

    AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_knn_cpu", [&] {
        auto data_q = query.data_ptr<scalar_t>();
        auto data_s = support.data_ptr<scalar_t>();
        std::vector<scalar_t> queries_stl = std::vector<scalar_t>(
            data_q, data_q + query.size(0)*query.size(1));
        std::vector<scalar_t> supports_stl = std::vector<scalar_t>(
            data_s, data_s + support.size(0)*support.size(1));

        int dim = torch::size(query, 1);
        // radius = 0, max_num = 0 and option = 0 select the k-NN branch of batch_nanoflann_neighbors.
        max_count = batch_nanoflann_neighbors<scalar_t>(queries_stl,
                                                        supports_stl,
                                                        size_query_batch_stl,
                                                        size_support_batch_stl,
                                                        neighbors_indices,
                                                        0,
                                                        dim,
                                                        0,
                                                        k, 0);
    });

    // As in knn_cpu: view the flat (index, index) pairs as an N x 2 tensor,
    // transpose, and swap the two rows (result[0] = out[1], result[1] = out[0]).
    size_t* neighbors_indices_ptr = neighbors_indices->data();
    const long long tsize = static_cast<long long>(neighbors_indices->size()/2);
    out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options);
    out = out.t();
    auto result = torch::zeros_like(out);
    auto index = torch::tensor({1,0});
    result.index_copy_(0, index, out);
    delete neighbors_indices;
    return result;
}
\ No newline at end of file
#pragma once
#include <torch/extension.h>
#include "utils/neighbors.cpp"
#include <iostream>
#include "compat.h"

torch::Tensor knn_cpu(torch::Tensor support, torch::Tensor query,
                      int64_t k, int64_t n_threads);

torch::Tensor batch_knn_cpu(torch::Tensor support,
                            torch::Tensor query,
                            torch::Tensor support_batch,
                            torch::Tensor query_batch,
                            int64_t k);
\ No newline at end of file
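The two entry points declared above back the CPU branch of the `torch.ops.torch_cluster.knn` dispatcher (see the knn dispatcher hunk further down). A hedged sketch of the batched call path from Python, assuming each example's points are stored contiguously so the `std::is_sorted` check passes:

import torch
from torch_cluster import knn

# Two point clouds packed into one tensor; the batch vectors are sorted by construction.
x = torch.rand(8, 3)                              # support points
batch_x = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
y = torch.rand(4, 3)                              # query points
batch_y = torch.tensor([0, 0, 1, 1])

# On CPU this ends up in batch_knn_cpu(x, y, batch_x, batch_y, k).
edge_index = knn(x, y, k=2, batch_x=batch_x, batch_y=batch_y)
print(edge_index.shape)    # torch.Size([2, 8]): 2 neighbours for each of the 4 queries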
@@ -4,7 +4,7 @@
 #include <cstdint>

-torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
+torch::Tensor radius_cpu(torch::Tensor support, torch::Tensor query,
                          double radius, int64_t max_num, int64_t n_threads){
     CHECK_CPU(query);
@@ -26,7 +26,7 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
         int dim = torch::size(query, 1);
-        max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl ,neighbors_indices, radius, dim, max_num, n_threads);
+        max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl ,neighbors_indices, radius, dim, max_num, n_threads, 0, 1);
     });
@@ -36,7 +36,13 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
     out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options=options);
     out = out.t();
-    return out.clone();
+    auto result = torch::zeros_like(out);
+    auto index = torch::tensor({1,0});
+    result.index_copy_(0, index, out);
+    return result;
 }
@@ -58,10 +64,10 @@ void get_size_batch(const std::vector<long>& batch, std::vector<long>& res){
     res[ind-batch[0]] = incr;
 }

-torch::Tensor batch_radius_cpu(torch::Tensor query,
-                               torch::Tensor support,
-                               torch::Tensor query_batch,
+torch::Tensor batch_radius_cpu(torch::Tensor support,
+                               torch::Tensor query,
                                torch::Tensor support_batch,
+                               torch::Tensor query_batch,
                                double radius, int64_t max_num) {
     CHECK_CPU(query);
@@ -103,8 +109,8 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
                                                         neighbors_indices,
                                                         radius,
                                                         dim,
-                                                        max_num
-                                                        );
+                                                        max_num,
+                                                        0, 1);
     });
     size_t* neighbors_indices_ptr = neighbors_indices->data();
@@ -114,5 +120,11 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
     out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options=options);
     out = out.t();
-    return out.clone();
+    auto result = torch::zeros_like(out);
+    auto index = torch::tensor({1,0});
+    result.index_copy_(0, index, out);
+    return result;
 }
\ No newline at end of file
 #pragma once
 #include <torch/extension.h>
-//#include "utils/neighbors.h"
 #include "utils/neighbors.cpp"
 #include <iostream>
 #include "compat.h"
@@ -3,6 +3,7 @@
 #include <set>
 #include <cstdint>
 #include <thread>
+#include <iostream>

 typedef struct thread_struct {
     void* kd_tree;
@@ -15,6 +16,8 @@ typedef struct thread_struct {
     size_t end;
     double search_radius;
     bool small;
+    bool option;
+    size_t k;
 } thread_args;

 template<typename scalar_t>
@@ -37,7 +40,7 @@ void thread_routine(thread_args* targs) {
     double search_radius = (double) targs->search_radius;
     size_t start = targs->start;
     size_t end = targs->end;
+    auto k = targs->k;

     for (size_t i = start; i < end; i++) {
         std::vector<scalar_t> p0 = *(((*pcd_query).pts)[i]);
@@ -46,11 +49,23 @@ void thread_routine(thread_args* targs) {
         std::copy(p0.begin(), p0.end(), query_pt);
         (*matches)[i].reserve(*max_count);
         std::vector<std::pair<size_t, scalar_t> > ret_matches;
+        std::vector<size_t>* knn_ret_matches = new std::vector<size_t>(k);
+        std::vector<scalar_t>* knn_dist_matches = new std::vector<scalar_t>(k);

         tree_m->lock();
-        const size_t nMatches = index->radiusSearch(query_pt, (scalar_t)(search_radius+eps), ret_matches, nanoflann::SearchParams());
+        size_t nMatches;
+        if (targs->option){
+            nMatches = index->radiusSearch(query_pt, (scalar_t)(search_radius+eps), ret_matches, nanoflann::SearchParams());
+        }
+        else {
+            nMatches = index->knnSearch(query_pt, k, &(*knn_ret_matches)[0], &(*knn_dist_matches)[0]);
+            auto temp = new std::vector<std::pair<size_t, scalar_t> >((*knn_dist_matches).size());
+            for (size_t j = 0; j < (*knn_ret_matches).size(); j++){
+                (*temp)[j] = std::make_pair((*knn_ret_matches)[j], (*knn_dist_matches)[j]);
+            }
+            ret_matches = *temp;
+        }
         tree_m->unlock();

         (*matches)[i] = ret_matches;
@@ -67,7 +82,8 @@ void thread_routine(thread_args* targs) {
 template<typename scalar_t>
 size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>& supports,
-    std::vector<size_t>*& neighbors_indices, double radius, int dim, int64_t max_num, int64_t n_threads){
+    std::vector<size_t>*& neighbors_indices, double radius, int dim,
+    int64_t max_num, int64_t n_threads, int64_t k, int option){

     const scalar_t search_radius = static_cast<scalar_t>(radius*radius);
@@ -120,9 +136,21 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
         (*list_matches)[i0].reserve(*max_count);
         std::vector<std::pair<size_t, scalar_t> > ret_matches;
+        std::vector<size_t>* knn_ret_matches = new std::vector<size_t>(k);
+        std::vector<scalar_t>* knn_dist_matches = new std::vector<scalar_t>(k);

-        const size_t nMatches = index->radiusSearch(query_pt, (scalar_t)(search_radius+eps), ret_matches, search_params);
+        size_t nMatches;
+        if (!!(option)){
+            nMatches = index->radiusSearch(query_pt, (scalar_t)(search_radius+eps), ret_matches, search_params);
+        }
+        else {
+            nMatches = index->knnSearch(query_pt, (size_t)k, &(*knn_ret_matches)[0], &(*knn_dist_matches)[0]);
+            auto temp = new std::vector<std::pair<size_t, scalar_t> >((*knn_dist_matches).size());
+            for (size_t j = 0; j < (*knn_ret_matches).size(); j++){
+                (*temp)[j] = std::make_pair((*knn_ret_matches)[j], (*knn_dist_matches)[j]);
+            }
+            ret_matches = *temp;
+        }

         (*list_matches)[i0] = ret_matches;
         if(*max_count < nMatches) *max_count = nMatches;
         i0++;
@@ -171,6 +199,8 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
         else {
             targs->small = false;
         }
+        targs->option = !!(option);
+        targs->k = k;
         std::thread* temp = new std::thread(thread_routine<scalar_t>, targs);
         tid[t] = temp;
     }
@@ -220,7 +250,7 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
                                   std::vector<long>& q_batches,
                                   std::vector<long>& s_batches,
                                   std::vector<size_t>*& neighbors_indices,
-                                  double radius, int dim, int64_t max_num){
+                                  double radius, int dim, int64_t max_num, int64_t k, int option){

     // indices
@@ -292,14 +322,22 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
         // Initial guess of neighbors size
         all_inds_dists[i0].reserve(max_count);
         // Find neighbors
-        size_t nMatches = index->radiusSearch(query_pt, r2+eps, all_inds_dists[i0], search_params);
-        // Update max count
-        std::vector<std::pair<size_t, float> > indices_dists;
-        nanoflann::RadiusResultSet<float,size_t> resultSet(r2, indices_dists);
-        index->findNeighbors(resultSet, query_pt, search_params);
+        size_t nMatches;
+        if (!!option) {
+            nMatches = index->radiusSearch(query_pt, r2+eps, all_inds_dists[i0], search_params);
+            // Update max count
+        }
+        else {
+            std::vector<size_t>* knn_ret_matches = new std::vector<size_t>(k);
+            std::vector<scalar_t>* knn_dist_matches = new std::vector<scalar_t>(k);
+            nMatches = index->knnSearch(query_pt, (size_t)k, &(*knn_ret_matches)[0], &(*knn_dist_matches)[0]);
+            auto temp = new std::vector<std::pair<size_t, scalar_t> >((*knn_dist_matches).size());
+            for (size_t j = 0; j < (*knn_ret_matches).size(); j++){
+                (*temp)[j] = std::make_pair((*knn_ret_matches)[j], (*knn_dist_matches)[j]);
+            }
+            all_inds_dists[i0] = *temp;
+        }
         if (nMatches > max_count)
             max_count = nMatches;
         // Increment query idx
@@ -4,21 +4,57 @@
 #ifdef WITH_CUDA
 #include "cuda/knn_cuda.h"
 #endif
+#include "cpu/knn_cpu.h"

 #ifdef _WIN32
 PyMODINIT_FUNC PyInit__knn(void) { return NULL; }
 #endif

-torch::Tensor knn(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
-                  torch::Tensor ptr_y, int64_t k, bool cosine) {
+torch::Tensor knn(torch::Tensor x, torch::Tensor y, torch::optional<torch::Tensor> ptr_x,
+                  torch::optional<torch::Tensor> ptr_y, int64_t k, bool cosine, int64_t n_threads) {
   if (x.device().is_cuda()) {
 #ifdef WITH_CUDA
-    return knn_cuda(x, y, ptr_x, ptr_y, k, cosine);
+    if (!(ptr_x.has_value()) && !(ptr_y.has_value())) {
+      auto batch_x = torch::tensor({0, torch::size(x, 0)}).to(torch::kLong).to(torch::kCUDA);
+      auto batch_y = torch::tensor({0, torch::size(y, 0)}).to(torch::kLong).to(torch::kCUDA);
+      return knn_cuda(x, y, batch_x, batch_y, k, cosine);
+    }
+    else if (!(ptr_x.has_value())) {
+      auto batch_x = torch::tensor({0, torch::size(x, 0)}).to(torch::kLong).to(torch::kCUDA);
+      auto batch_y = ptr_y.value();
+      return knn_cuda(x, y, batch_x, batch_y, k, cosine);
+    }
+    else if (!(ptr_y.has_value())) {
+      auto batch_x = ptr_x.value();
+      auto batch_y = torch::tensor({0, torch::size(y, 0)}).to(torch::kLong).to(torch::kCUDA);
+      return knn_cuda(x, y, batch_x, batch_y, k, cosine);
+    }
+    auto batch_x = ptr_x.value();
+    auto batch_y = ptr_y.value();
+    return knn_cuda(x, y, batch_x, batch_y, k, cosine);
 #else
     AT_ERROR("Not compiled with CUDA support");
 #endif
   } else {
-    AT_ERROR("No CPU version supported");
+    if (cosine) {
+      AT_ERROR("`cosine` argument not supported on CPU");
+    }
+    if (!(ptr_x.has_value()) && !(ptr_y.has_value())) {
+      return knn_cpu(x, y, k, n_threads);
+    }
+    if (!(ptr_x.has_value())) {
+      auto batch_x = torch::zeros({torch::size(x, 0)}).to(torch::kLong);
+      auto batch_y = ptr_y.value();
+      return batch_knn_cpu(x, y, batch_x, batch_y, k);
+    }
+    else if (!(ptr_y.has_value())) {
+      auto batch_x = ptr_x.value();
+      auto batch_y = torch::zeros({torch::size(y, 0)}).to(torch::kLong);
+      return batch_knn_cpu(x, y, batch_x, batch_y, k);
+    }
+    auto batch_x = ptr_x.value();
+    auto batch_y = ptr_y.value();
+    return batch_knn_cpu(x, y, batch_x, batch_y, k);
   }
 }
@@ -3,7 +3,7 @@ from itertools import product

 import pytest
 import torch
 from torch_cluster import knn, knn_graph
+import pickle

 from .utils import grad_dtypes, devices, tensor
@@ -57,3 +57,35 @@ def test_knn_graph(dtype, device):
     row = row.view(-1, 2).sort(dim=-1)[0].view(-1)

     assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
     assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
+
+
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+def test_knn_graph_large(dtype, device):
+    d = pickle.load(open("test/knn_test_large.pkl", "rb"))
+    x = d['x'].to(device)
+    k = d['k']
+    truth = d['edges']
+
+    row, col = knn_graph(x, k=k, flow='source_to_target',
+                         batch=None, n_threads=24)
+    edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
+                                          list(col.cpu().numpy()))])
+    assert(truth == edges)
+
+    row, col = knn_graph(x, k=k, flow='source_to_target',
+                         batch=None, n_threads=12)
+    edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
+                                          list(col.cpu().numpy()))])
+    assert(truth == edges)
+
+    row, col = knn_graph(x, k=k, flow='source_to_target',
+                         batch=None, n_threads=1)
+    edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
+                                          list(col.cpu().numpy()))])
+    assert(truth == edges)
 from typing import Optional

 import torch
-import scipy.spatial
+import numpy as np


 def knn(x: torch.Tensor, y: torch.Tensor, k: int,
         batch_x: Optional[torch.Tensor] = None,
         batch_y: Optional[torch.Tensor] = None,
-        cosine: bool = False) -> torch.Tensor:
+        cosine: bool = False, n_threads: int = 1) -> torch.Tensor:
     r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
     :obj:`x`.
@@ -44,9 +44,13 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
     x = x.view(-1, 1) if x.dim() == 1 else x
     y = y.view(-1, 1) if y.dim() == 1 else y

+    def is_sorted(x):
+        return (np.diff(x.detach().cpu()) >= 0).all()
+
     if x.is_cuda:
         if batch_x is not None:
             assert x.size(0) == batch_x.numel()
+            assert is_sorted(batch_x)
             batch_size = int(batch_x.max()) + 1

             deg = x.new_zeros(batch_size, dtype=torch.long)
@@ -59,6 +63,7 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
         if batch_y is not None:
             assert y.size(0) == batch_y.numel()
+            assert is_sorted(batch_y)
             batch_size = int(batch_y.max()) + 1

             deg = y.new_zeros(batch_size, dtype=torch.long)
@@ -69,51 +74,32 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
         else:
             ptr_y = torch.tensor([0, y.size(0)], device=y.device)

-        return torch.ops.torch_cluster.knn(x, y, ptr_x, ptr_y, k, cosine)
+        return torch.ops.torch_cluster.knn(x, y, ptr_x,
+                                           ptr_y, k, cosine, n_threads)
     else:
-        if batch_x is None:
-            batch_x = x.new_zeros(x.size(0), dtype=torch.long)
-        if batch_y is None:
-            batch_y = y.new_zeros(y.size(0), dtype=torch.long)
-        assert x.dim() == 2 and batch_x.dim() == 1
-        assert y.dim() == 2 and batch_y.dim() == 1
+        assert x.dim() == 2
+        if batch_x is not None:
+            assert batch_x.dim() == 1
+            assert is_sorted(batch_x)
+            assert x.size(0) == batch_x.size(0)
+        assert y.dim() == 2
+        if batch_y is not None:
+            assert batch_y.dim() == 1
+            assert is_sorted(batch_y)
+            assert y.size(0) == batch_y.size(0)
         assert x.size(1) == y.size(1)
-        assert x.size(0) == batch_x.size(0)
-        assert y.size(0) == batch_y.size(0)

         if cosine:
             raise NotImplementedError('`cosine` argument not supported on CPU')

-        # Translate and rescale x and y to [0, 1].
-        min_xy = min(x.min().item(), y.min().item())
-        x, y = x - min_xy, y - min_xy
-        max_xy = max(x.max().item(), y.max().item())
-        x.div_(max_xy)
-        y.div_(max_xy)
-        # Concat batch/features to ensure no cross-links between examples.
-        x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], -1)
-        y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], -1)
-        tree = scipy.spatial.cKDTree(x.detach().numpy())
-        dist, col = tree.query(y.detach().cpu(), k=k,
-                               distance_upper_bound=x.size(1))
-        dist = torch.from_numpy(dist).to(x.dtype)
-        col = torch.from_numpy(col).to(torch.long)
-        row = torch.arange(col.size(0), dtype=torch.long)
-        row = row.view(-1, 1).repeat(1, k)
-        mask = ~torch.isinf(dist).view(-1)
-        row, col = row.view(-1)[mask], col.view(-1)[mask]
-        return torch.stack([row, col], dim=0)
+        return torch.ops.torch_cluster.knn(x, y, batch_x, batch_y,
+                                           k, cosine, n_threads)


 def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
               loop: bool = False, flow: str = 'source_to_target',
-              cosine: bool = False) -> torch.Tensor:
+              cosine: bool = False, n_threads: int = 1) -> torch.Tensor:
     r"""Computes graph edges to the nearest :obj:`k` points.

     Args:
@@ -145,7 +131,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
     """

     assert flow in ['source_to_target', 'target_to_source']
-    row, col = knn(x, x, k if loop else k + 1, batch, batch, cosine=cosine)
+    row, col = knn(x, x, k if loop else k + 1, batch, batch,
+                   cosine=cosine, n_threads=n_threads)
     row, col = (col, row) if flow == 'source_to_target' else (row, col)

     if not loop:
         mask = row != col
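One consequence of the `is_sorted` guard added above: batch vectors now have to be in ascending order on both the CPU and CUDA paths. A small illustrative sketch (the tensors are made up for the example, not taken from the commit):

import torch
from torch_cluster import knn

x = torch.rand(4, 3)
ok = torch.tensor([0, 0, 1, 1])    # sorted per-example blocks: accepted
bad = torch.tensor([0, 1, 0, 1])   # interleaved examples: rejected

knn(x, x, k=2, batch_x=ok, batch_y=ok)       # fine
# knn(x, x, k=2, batch_x=bad, batch_y=bad)   # AssertionError from the is_sorted check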
@@ -75,7 +75,6 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
         result = torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
                                                 max_num_neighbors, n_threads)
     else:
-
         assert x.dim() == 2
         if batch_x is not None:
             assert batch_x.dim() == 1
@@ -136,12 +135,7 @@ def radius_graph(x: torch.Tensor, r: float,
     row, col = radius(x, x, r, batch, batch,
                       max_num_neighbors if loop else max_num_neighbors + 1,
                       n_threads)
-    if x.is_cuda:
-        row, col = (col, row) if flow == 'source_to_target' else (row, col)
-    else:
-        row, col = (col, row) if flow == 'target_to_source' else (row, col)
+    row, col = (col, row) if flow == 'source_to_target' else (row, col)

     if not loop:
         mask = row != col
         row, col = row[mask], col[mask]