Commit c987d065 authored by rusty1s's avatar rusty1s
Browse files

k-nearest neighbor

parent bea734f1
#include <torch/torch.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous");
at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
at::Tensor batch_y);
at::Tensor knn(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
at::Tensor batch_y) {
CHECK_CUDA(x);
IS_CONTIGUOUS(x);
CHECK_CUDA(y);
IS_CONTIGUOUS(y);
CHECK_CUDA(batch_x);
CHECK_CUDA(batch_y);
return knn_cuda(x, y, k, batch_x, batch_y);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("knn", &knn, "k-Nearest Neighbor (CUDA)");
}
#include <ATen/ATen.h>
#include "utils.cuh"
#define THREADS 1024
template <typename scalar_t>
__global__ void
knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
const int64_t *__restrict__ batch_x,
const int64_t *__restrict__ batch_y, scalar_t *__restrict__ dist,
int64_t *__restrict__ row, int64_t *__restrict__ col, size_t k,
size_t dim) {
const ptrdiff_t batch_idx = blockIdx.x;
const ptrdiff_t idx = threadIdx.x;
const ptrdiff_t start_idx_x = batch_x[batch_idx];
const ptrdiff_t end_idx_x = batch_x[batch_idx + 1];
const ptrdiff_t start_idx_y = batch_y[batch_idx];
const ptrdiff_t end_idx_y = batch_y[batch_idx + 1];
for (ptrdiff_t n_y = start_idx_y + idx; n_y < end_idx_y; n_y += THREADS) {
for (ptrdiff_t k_idx = 0; k_idx < k; k_idx++) {
row[n_y * k + k_idx] = n_y;
}
for (ptrdiff_t n_x = start_idx_x; n_x < end_idx_x; n_x++) {
scalar_t tmp_dist = 0;
for (ptrdiff_t d = 0; d < dim; d++) {
tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
}
for (ptrdiff_t k_idx_1 = 0; k_idx_1 < k; k_idx_1++) {
if (dist[n_y * k + k_idx_1] > tmp_dist) {
for (ptrdiff_t k_idx_2 = k - 1; k_idx_2 > k_idx_1; k_idx_2--) {
dist[n_y * k + k_idx_2] = dist[n_y * k + k_idx_2 - 1];
col[n_y * k + k_idx_2] = col[n_y * k + k_idx_2 - 1];
}
dist[n_y * k + k_idx_1] = tmp_dist;
col[n_y * k + k_idx_1] = n_x;
break;
}
}
}
}
}
at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
at::Tensor batch_y) {
auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
cudaMemcpyDeviceToHost);
auto batch_size = batch_sizes[0] + 1;
batch_x = degree(batch_x, batch_size);
batch_x = at::cat({at::zeros(1, batch_x.options()), batch_x.cumsum(0)}, 0);
batch_y = degree(batch_y, batch_size);
batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);
auto dist = at::full(y.size(0) * k, 1e38, y.options());
auto row = at::empty(y.size(0) * k, batch_y.options());
auto col = at::empty(y.size(0) * k, batch_y.options());
AT_DISPATCH_FLOATING_TYPES(x.type(), "knn_kernel", [&] {
knn_kernel<scalar_t><<<batch_size, THREADS>>>(
x.data<scalar_t>(), y.data<scalar_t>(), batch_x.data<int64_t>(),
batch_y.data<int64_t>(), dist.data<scalar_t>(), row.data<int64_t>(),
col.data<int64_t>(), k, x.size(1));
});
return at::stack({row, col}, 0);
}
......@@ -30,7 +30,6 @@ __global__ void nearest_kernel(const scalar_t *__restrict__ x,
dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
}
dist = sqrt(dist);
if (dist < best) {
best = dist;
......
#include <torch/torch.h>
#define IS_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor");
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous");
#define CHECK_INPUT(x) IS_CUDA(x) IS_CONTIGUOUS(x)
std::vector<at::Tensor> query_radius_cuda(
int batch_size,
at::Tensor batch_slices,
at::Tensor query_batch_slices,
at::Tensor pos,
at::Tensor query_pos,
double radius,
int max_num_neighbors,
bool include_self);
std::vector<at::Tensor> query_radius(
int batch_size,
at::Tensor batch_slices,
at::Tensor query_batch_slices,
at::Tensor pos,
at::Tensor query_pos,
double radius,
int max_num_neighbors,
bool include_self) {
CHECK_INPUT(batch_slices);
CHECK_INPUT(query_batch_slices);
CHECK_INPUT(pos);
CHECK_INPUT(query_pos);
return query_radius_cuda(batch_size, batch_slices, query_batch_slices, pos, query_pos, radius, max_num_neighbors, include_self);
}
std::vector<at::Tensor> query_knn_cuda(
int batch_size,
at::Tensor batch_slices,
at::Tensor query_batch_slices,
at::Tensor pos,
at::Tensor query_pos,
int num_neighbors,
bool include_self);
std::vector<at::Tensor> query_knn(
int batch_size,
at::Tensor batch_slices,
at::Tensor query_batch_slices,
at::Tensor pos,
at::Tensor query_pos,
int num_neighbors,
bool include_self) {
CHECK_INPUT(batch_slices);
CHECK_INPUT(query_batch_slices);
CHECK_INPUT(pos);
CHECK_INPUT(query_pos);
return query_knn_cuda(batch_size, batch_slices, query_batch_slices, pos, query_pos, num_neighbors, include_self);
}
at::Tensor farthest_point_sampling_cuda(
int batch_size,
at::Tensor batch_slices,
at::Tensor pos,
int num_sample,
at::Tensor start_points);
at::Tensor farthest_point_sampling(
int batch_size,
at::Tensor batch_slices,
at::Tensor pos,
int num_sample,
at::Tensor start_points) {
CHECK_INPUT(batch_slices);
CHECK_INPUT(pos);
CHECK_INPUT(start_points);
return farthest_point_sampling_cuda(batch_size, batch_slices, pos, num_sample, start_points);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("query_radius", &query_radius, "Query Radius (CUDA)");
m.def("query_knn", &query_knn, "Query K-Nearest neighbor (CUDA)");
m.def("farthest_point_sampling", &farthest_point_sampling, "Farthest Point Sampling (CUDA)");
}
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdint.h>
#include <vector>
#define THREADS 1024
// Original by Qi. et al (https://github.com/charlesq34/pointnet2)
template <typename scalar_t>
__global__ void query_radius_cuda_kernel(
const int64_t* __restrict__ batch_slices,
const int64_t* __restrict__ query_batch_slices,
const scalar_t* __restrict__ pos,
const scalar_t* __restrict__ query_pos,
const scalar_t radius,
const int64_t max_num_neighbors,
const bool include_self,
int64_t* idx_output,
int64_t* cnt_output)
{
const int64_t batch_index = blockIdx.x;
const int64_t index = threadIdx.x;
const int64_t stride = blockDim.x;
const int64_t batch_start = batch_slices[2*batch_index];
const int64_t query_batch_start = query_batch_slices[2*batch_index];
const int64_t batch_end = batch_slices[2*batch_index+1];
const int64_t query_batch_end = query_batch_slices[2*batch_index+1];
const int64_t batch_size = batch_end - batch_start + 1;
const int64_t query_batch_size = query_batch_end - query_batch_start + 1;
pos += batch_start * 3;
query_pos += query_batch_start * 3;
idx_output += query_batch_start * max_num_neighbors;
cnt_output += query_batch_start;
for (int64_t j = index; j < query_batch_size; j+=stride){
int64_t cnt = 0;
scalar_t x2=query_pos[j*3+0];
scalar_t y2=query_pos[j*3+1];
scalar_t z2=query_pos[j*3+2];
// dummy outputs initialisation with value -1
if (cnt==0) {
for (int64_t l = 0;l < max_num_neighbors; l++)
idx_output[j*max_num_neighbors+l] = -1;
}
for (int64_t k = 0; k < batch_size; k++) {
if (cnt == max_num_neighbors)
break;
scalar_t x1=pos[k*3+0];
scalar_t y1=pos[k*3+1];
scalar_t z1=pos[k*3+2];
scalar_t d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
if (d <= radius && (d > 0 || include_self)) {
idx_output[j * max_num_neighbors + cnt] = batch_start + k;
cnt+=1;
}
}
cnt_output[j] = cnt;
}
}
template <typename scalar_t>
__global__ void query_knn_cuda_kernel(
const int64_t* __restrict__ batch_slices,
const int64_t* __restrict__ query_batch_slices,
const scalar_t* __restrict__ pos,
const scalar_t* __restrict__ query_pos,
const int64_t num_neighbors,
const bool include_self,
scalar_t* tmp_dists,
int64_t* idx_output){
const int64_t batch_index = blockIdx.x;
const int64_t index = threadIdx.x;
const int64_t stride = blockDim.x;
const int64_t batch_start = batch_slices[2*batch_index];
const int64_t query_batch_start = query_batch_slices[2*batch_index];
const int64_t batch_end = batch_slices[2*batch_index+1];
const int64_t query_batch_end = query_batch_slices[2*batch_index+1];
const int64_t batch_size = batch_end - batch_start + 1;
const int64_t query_batch_size = query_batch_end - query_batch_start + 1;
pos += batch_start * 3;
query_pos += query_batch_start * 3;
idx_output += query_batch_start * num_neighbors;
tmp_dists += query_batch_start * num_neighbors;
for (int64_t j = index; j < query_batch_size; j += stride){
scalar_t x2=query_pos[j*3+0];
scalar_t y2=query_pos[j*3+1];
scalar_t z2=query_pos[j*3+2];
// reset to dummy values
for (int64_t l = 0; l < num_neighbors; l++){
idx_output[j * num_neighbors + l] = -1;
tmp_dists[j * num_neighbors + l] = 2147483647;
}
for (int64_t k = 0; k < batch_size; k++) {
scalar_t x1=pos[k*3+0];
scalar_t y1=pos[k*3+1];
scalar_t z1=pos[k*3+2];
scalar_t d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
if (d > 0 || include_self){
for (int64_t i = 0; i < num_neighbors; i++){
if (tmp_dists[j * num_neighbors + i] > d){
for (int64_t i2 = num_neighbors-1; i2 > i; i2--){
tmp_dists[j * num_neighbors + i2] = tmp_dists[j * num_neighbors + i2 - 1];
idx_output[j * num_neighbors + i2] = idx_output[j * num_neighbors + i2 - 1];
}
tmp_dists[j * num_neighbors + i] = d;
idx_output[j * num_neighbors + i] = batch_start + k;
break;
}
}
}
}
}
}
template <typename scalar_t>
__global__ void farthest_point_sampling_kernel(
const int64_t* __restrict__ batch_slices,
const scalar_t* __restrict__ pos,
const int64_t num_sample,
const int64_t* __restrict__ start_points,
scalar_t* tmp_dists,
int64_t* idx_output){
const int64_t batch_index = blockIdx.x;
const int64_t index = threadIdx.x;
const int64_t stride = blockDim.x;
const int64_t batch_start = batch_slices[2*batch_index];
const int64_t batch_end = batch_slices[2*batch_index+1];
const int64_t batch_size = batch_end - batch_start + 1;
__shared__ scalar_t dists[THREADS];
__shared__ int64_t dists_i[THREADS];
pos += batch_start * 3;
idx_output += num_sample * batch_index;
tmp_dists += batch_start;
int64_t old = start_points[batch_index];
// explicitly handle the case where less than num_sample points are available to sample from
if (index == 0){
idx_output[0] = batch_start + old;
if (batch_size < num_sample){
for (int64_t i = 0; i < batch_size; i++){
idx_output[i] = batch_start + i;
}
for (int64_t i = batch_size; i < num_sample; i++){
idx_output[i] = -1;
}
}
}
if (batch_size < num_sample){
return;
}
// initialise temporary distances as big as possible
for (int64_t j = index; j < batch_size; j+=stride){
tmp_dists[j] = 2147483647;
}
__syncthreads();
for (int64_t i = 1; i < num_sample; i++){
int64_t besti = -1;
scalar_t best = -1;
// compute distance from last point to all others and update using the minimum of already computed distances
for (int64_t j = index; j < batch_size; j+= stride){
scalar_t td = tmp_dists[j];
scalar_t x1 = pos[old * 3 + 0];
scalar_t y1 = pos[old * 3 + 1];
scalar_t z1 = pos[old * 3 + 2];
scalar_t x2 = pos[j * 3 + 0];
scalar_t y2 = pos[j * 3 + 1];
scalar_t z2 = pos[j * 3 + 2];
scalar_t d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
scalar_t d2 = min(d, tmp_dists[j]);
if (td != d2){
tmp_dists[j] = d2;
}
if (tmp_dists[j] > best){
best = tmp_dists[j];
besti = j;
}
}
// sort best indices
dists[index] = best;
dists_i[index] = besti;
__syncthreads();
// get the maximum distances (by merging)
for (int64_t u = 0; (1<<u) < blockDim.x ; u++){
__syncthreads();
if (index < (blockDim.x >> (u+1))){
int64_t i1 = (index*2)<<u;
int64_t i2 = (index*2+1)<<u;
if (dists[i1] < dists[i2]){
dists[i1] = dists[i2];
dists_i[i1] = dists_i[i2];
}
}
}
__syncthreads();
if (dists[0] == 0){
break;
}
// thread 0 collects in output
old = dists_i[0];
if (index == 0){
idx_output[i] = batch_start + old;
}
}
}
std::vector<at::Tensor> query_radius_cuda(
int batch_size,
at::Tensor batch_slices,
at::Tensor query_batch_slices,
at::Tensor pos,
at::Tensor query_pos,
double radius,
int max_num_neighbors,
bool include_self) {
const auto num_points = query_pos.size(0);
auto idx_output = at::empty(pos.type().toScalarType(at::kLong), {num_points, max_num_neighbors});
auto cnt_output = at::empty(pos.type().toScalarType(at::kLong), {num_points});
AT_DISPATCH_FLOATING_TYPES(pos.type(), "query_radius_cuda_kernel", [&] {
query_radius_cuda_kernel<scalar_t><<<batch_size, THREADS>>>(
batch_slices.data<int64_t>(),
query_batch_slices.data<int64_t>(),
pos.data<scalar_t>(),
query_pos.data<scalar_t>(),
(scalar_t) radius*radius,
max_num_neighbors,
include_self,
idx_output.data<int64_t>(),
cnt_output.data<int64_t>());
});
return {idx_output, cnt_output};
}
std::vector<at::Tensor> query_knn_cuda(
int batch_size,
at::Tensor batch_slices,
at::Tensor query_batch_slices,
at::Tensor pos,
at::Tensor query_pos,
int num_neighbors,
bool include_self) {
const auto num_points = query_pos.size(0);
auto idx_output = at::empty(pos.type().toScalarType(at::kLong), {num_points, num_neighbors});
auto dists = at::empty(pos.type(), {num_points, num_neighbors});
AT_DISPATCH_FLOATING_TYPES(pos.type(), "query_knn_cuda_kernel", [&] {
query_knn_cuda_kernel<scalar_t><<<batch_size, THREADS>>>(
batch_slices.data<int64_t>(),
query_batch_slices.data<int64_t>(),
pos.data<scalar_t>(),
query_pos.data<scalar_t>(),
num_neighbors,
include_self,
dists.data<scalar_t>(),
idx_output.data<int64_t>());
});
return {idx_output, dists};
}
at::Tensor farthest_point_sampling_cuda(
int batch_size,
at::Tensor batch_slices,
at::Tensor pos,
int num_sample,
at::Tensor start_points) {
auto idx_output = at::empty(pos.type().toScalarType(at::kLong), {batch_size * num_sample});
auto tmp_dists = at::empty(pos.type(), {pos.size(0)});
AT_DISPATCH_FLOATING_TYPES(pos.type(), "farthest_point_sampling_kernel", [&] {
farthest_point_sampling_kernel<scalar_t><<<batch_size, THREADS>>>(
batch_slices.data<int64_t>(),
pos.data<scalar_t>(),
num_sample,
start_points.data<int64_t>(),
tmp_dists.data<scalar_t>(),
idx_output.data<int64_t>());
});
return idx_output;
}
......@@ -16,6 +16,7 @@ if torch.cuda.is_available():
CUDAExtension('fps_cuda', ['cuda/fps.cpp', 'cuda/fps_kernel.cu']),
CUDAExtension('nearest_cuda',
['cuda/nearest.cpp', 'cuda/nearest_kernel.cu']),
CUDAExtension('knn_cuda', ['cuda/knn.cpp', 'cuda/knn_kernel.cu']),
CUDAExtension('radius_cuda',
['cuda/radius.cpp', 'cuda/radius_kernel.cu']),
]
......
from itertools import product
import pytest
import torch
from torch_cluster import knn
from .utils import tensor
devices = [torch.device('cuda')]
grad_dtypes = [torch.float]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius(dtype, device):
x = tensor([
[-1, -1],
[-1, +1],
[+1, +1],
[+1, -1],
[-1, -1],
[-1, +1],
[+1, +1],
[+1, -1],
], dtype, device)
y = tensor([
[1, 0],
[-1, 0],
], dtype, device)
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
batch_y = tensor([0, 1], torch.long, device)
out = knn(x, y, 2, batch_x, batch_y)
assert out.tolist() == [[0, 0, 1, 1], [2, 3, 4, 5]]
......@@ -4,10 +4,9 @@ import pytest
import torch
from torch_cluster import radius
from .utils import tensor
from .utils import tensor, grad_dtypes
devices = [torch.device('cuda')]
grad_dtypes = [torch.float]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
......@@ -32,6 +31,4 @@ def test_radius(dtype, device):
batch_y = tensor([0, 1], torch.long, device)
out = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
print()
print([[0, 0, 0, 0, 1, 1], [0, 1, 2, 3, 5, 6]])
print(out)
assert out.tolist() == [[0, 0, 0, 0, 1, 1], [0, 1, 2, 3, 5, 6]]
import pytest
import torch
import numpy as np
from torch_cluster.sample import (sample_farthest, batch_slices,
radius_query_edges)
from .utils import tensor, grad_dtypes, devices
@pytest.mark.parametrize('device', devices)
def test_batch_slices(device):
# test sample case for correctness
batch = tensor(
[0] * 100 + [1] * 50 + [2] * 42, dtype=torch.long, device=device)
slices, sizes = batch_slices(batch, sizes=True)
slices, sizes = slices.cpu().tolist(), sizes.cpu().tolist()
assert slices == [0, 99, 100, 149, 150, 191]
assert sizes == [100, 50, 42]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype', grad_dtypes)
def test_fps(dtype):
# test simple case for correctness
batch = [0] * 10
points = [[-1, -1, 0], [-1, 1, 0], [1, 1, 0], [1, -1, 0]]
random_points = np.random.uniform(-1, 1, size=(6, 3))
random_points[:, 2] = 0
random_points = random_points.tolist()
batch = tensor(batch, dtype=torch.long, device='cuda')
pos = tensor(points + random_points, dtype=dtype, device='cuda')
sample_farthest(batch, pos, num_sampled=4, index=True)
# needs update since isin is missing (sort indices, then compare?)
# assert isin(idx, tensor([0, 1, 2, 3], dtype=torch.long, device='cuda'),
# False).all().cpu().item() == 1
# test variable number of points for each element in a batch
batch = [0] * 100 + [1] * 50
points1 = np.random.uniform(-1, 1, size=(100, 3)).tolist()
points2 = np.random.uniform(-1, 1, size=(50, 3)).tolist()
batch = tensor(batch, dtype=torch.long, device='cuda')
pos = tensor(points1 + points2, dtype=dtype, device='cuda')
mask = sample_farthest(batch, pos, num_sampled=75, index=False)
assert mask[batch == 0].sum().item() == 75
assert mask[batch == 1].sum().item() == 50
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype', grad_dtypes)
def test_radius_edges(dtype):
batch = [0] * 100 + [1] * 50 + [2] * 42
points = np.random.uniform(-1, 1, size=(192, 3))
query_batch = [0] * 10 + [1] * 15 + [2] * 20
query_points = np.random.uniform(-1, 1, size=(45, 3))
radius = 0.5
batch = tensor(batch, dtype=torch.long, device='cuda')
query_batch = tensor(query_batch, dtype=torch.long, device='cuda')
pos = tensor(points, dtype=dtype, device='cuda')
query_pos = tensor(query_points, dtype=dtype, device='cuda')
edge_index = radius_query_edges(
batch,
pos,
query_batch,
query_pos,
radius=radius,
max_num_neighbors=128)
row, col = edge_index
dist = torch.norm(pos[col] - query_pos[row], p=2, dim=1)
assert (dist <= radius).all().item()
......@@ -2,6 +2,7 @@ from .graclus import graclus_cluster
from .grid import grid_cluster
from .fps import fps
from .nearest import nearest
from .knn import knn, knn_graph
from .radius import radius, radius_graph
__version__ = '1.2.0'
......@@ -11,6 +12,8 @@ __all__ = [
'grid_cluster',
'fps',
'nearest',
'knn',
'knn_graph',
'radius',
'radius_graph',
'__version__',
......
import torch
if torch.cuda.is_available():
import knn_cuda
def knn(x, y, k, batch_x=None, batch_y=None):
"""Finds for each element in `y` the `k` nearest points in `x`.
Args:
x (Tensor): D-dimensional point features.
y (Tensor): D-dimensional point features.
k (int): The number of neighbors.
batch_x (LongTensor, optional): Vector that maps each point to its
example identifier. If :obj:`None`, all points belong to the same
example. If not :obj:`None`, points in the same example need to
have contiguous memory layout and :obj:`batch` needs to be
ascending. (default: :obj:`None`)
batch_y (LongTensor, optional): See `batch_x` (default: :obj:`None`)
:rtype: :class:`LongTensor`
Examples::
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
>>> batch_x = torch.Tensor([0, 0, 0, 0])
>>> y = torch.Tensor([[-1, 0], [1, 0]])
>>> batch_x = torch.Tensor([0, 0])
>>> out = knn(x, y, 2, batch_x, batch_y)
"""
if batch_x is None:
batch_x = x.new_zeros(x.size(0), dtype=torch.long)
if batch_y is None:
batch_y = y.new_zeros(y.size(0), dtype=torch.long)
x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y
assert x.is_cuda
assert x.dim() == 2 and batch_x.dim() == 1
assert y.dim() == 2 and batch_y.dim() == 1
assert x.size(1) == y.size(1)
assert x.size(0) == batch_x.size(0)
assert y.size(0) == batch_y.size(0)
op = knn_cuda.knn if x.is_cuda else None
assign_index = op(x, y, k, batch_x, batch_y)
return assign_index
def knn_graph(x, k, batch=None):
"""Finds for each element in `x` the `k` nearest points.
Args:
x (Tensor): D-dimensional point features.
k (int): The number of neighbors.
batch (LongTensor, optional): Vector that maps each point to its
example identifier. If :obj:`None`, all points belong to the same
example. If not :obj:`None`, points in the same example need to
have contiguous memory layout and :obj:`batch` needs to be
ascending. (default: :obj:`None`)
:rtype: :class:`LongTensor`
Examples::
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
>>> batch = torch.Tensor([0, 0, 0, 0])
>>> out = knn_graph(x, 2, batch)
"""
edge_index = knn(x, x, k + 1, batch, batch)
row, col = edge_index
mask = row != col
row, col = row[mask], col[mask]
return torch.stack([row, col], dim=0)
import torch
if torch.cuda.is_available():
from sample_cuda import farthest_point_sampling, query_radius, query_knn
def batch_slices(batch, sizes=False, include_ends=True):
"""
Calculates size, start and end indices for each element in a batch.
"""
batch_size = batch.max().item() + 1
size = batch.new_zeros(batch_size).scatter_add_(0, batch,
torch.ones_like(batch))
cumsum = torch.cumsum(size, dim=0)
starts = cumsum - size
ends = cumsum - 1
slices = starts
if include_ends:
slices = torch.stack([starts, ends], dim=1).view(-1)
if sizes:
return slices, size
return slices
def sample_farthest(batch, pos, num_sampled, random_start=False, index=False):
"""Samples a specified number of points for each element in a batch using
farthest iterative point sampling and returns a mask (or indices) for the
sampled points. If there are less than num_sampled points in a point cloud
all points are returned.
"""
if not pos.is_cuda or not batch.is_cuda:
raise NotImplementedError
assert pos.is_contiguous() and batch.is_contiguous()
slices, sizes = batch_slices(batch, sizes=True)
batch_size = batch.max().item() + 1
if random_start:
random = torch.rand(batch_size, device=slices.device)
start_points = (sizes.float() * random).long()
else:
start_points = torch.zeros_like(sizes)
idx = farthest_point_sampling(batch_size, slices, pos, num_sampled,
start_points)
# Remove invalid indices
idx = idx[idx != -1]
if index:
return idx
mask = torch.zeros(pos.size(0), dtype=torch.uint8, device=pos.device)
mask[idx] = 1
return mask
def radius_query_edges(batch,
pos,
query_batch,
query_pos,
radius,
max_num_neighbors=128,
include_self=True):
if not pos.is_cuda:
raise NotImplementedError
assert pos.is_cuda and batch.is_cuda
assert query_pos.is_cuda and query_batch.is_cuda
assert pos.is_contiguous() and batch.is_contiguous()
assert query_pos.is_contiguous() and query_batch.is_contiguous()
slices, sizes = batch_slices(batch, sizes=True)
batch_size = batch.max().item() + 1
query_slices = batch_slices(query_batch)
max_num_neighbors = min(max_num_neighbors, sizes.max().item())
idx, cnt = query_radius(batch_size, slices, query_slices, pos, query_pos,
radius, max_num_neighbors, include_self)
# Convert to edges
view = idx.view(-1)
row = torch.arange(query_pos.size(0), dtype=torch.long, device=pos.device)
row = row.view(-1, 1).repeat(1, max_num_neighbors).view(-1)
# Remove invalid indices
row = row[view != -1]
col = view[view != -1]
if col.size(0) == 0:
return col
edge_index = torch.stack([row, col], dim=0)
return edge_index
def radius_graph(batch, pos, radius, max_num_neighbors=128,
include_self=False):
return radius_query_edges(batch, pos, batch, pos, radius,
max_num_neighbors, include_self)
def knn_query_edges(batch,
pos,
query_batch,
query_pos,
num_neighbors,
include_self=True):
if not pos.is_cuda:
raise NotImplementedError
assert pos.is_cuda and batch.is_cuda
assert query_pos.is_cuda and query_batch.is_cuda
assert pos.is_contiguous() and batch.is_contiguous()
assert query_pos.is_contiguous() and query_batch.is_contiguous()
slices, sizes = batch_slices(batch, sizes=True)
batch_size = batch.max().item() + 1
query_slices = batch_slices(query_batch)
assert (sizes < num_neighbors).sum().item() == 0
idx, dists = query_knn(batch_size, slices, query_slices, pos, query_pos,
num_neighbors, include_self)
# Convert to edges
view = idx.view(-1)
row = torch.arange(query_pos.size(0), dtype=torch.long, device=pos.device)
row = row.view(-1, 1).repeat(1, num_neighbors).view(-1)
# Remove invalid indices
row = row[view != -1]
col = view[view != -1]
edge_index = torch.stack([row, col], dim=0)
return edge_index
def knn_graph(batch, pos, num_neighbors, include_self=False):
return knn_query_edges(batch, pos, batch, pos, num_neighbors, include_self)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment