Unverified Commit 751dd81d authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #8 from mludolph/master

farthest point sampling and radius/knn query
parents f7af865f 95862f67
#include <torch/torch.h>
#define IS_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor");
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous");
#define CHECK_INPUT(x) IS_CUDA(x) IS_CONTIGUOUS(x)
std::vector<at::Tensor> query_radius_cuda(
int batch_size,
at::Tensor batch_slices,
at::Tensor query_batch_slices,
at::Tensor pos,
at::Tensor query_pos,
double radius,
int max_num_neighbors,
bool include_self);
std::vector<at::Tensor> query_radius(
int batch_size,
at::Tensor batch_slices,
at::Tensor query_batch_slices,
at::Tensor pos,
at::Tensor query_pos,
double radius,
int max_num_neighbors,
bool include_self) {
CHECK_INPUT(batch_slices);
CHECK_INPUT(query_batch_slices);
CHECK_INPUT(pos);
CHECK_INPUT(query_pos);
return query_radius_cuda(batch_size, batch_slices, query_batch_slices, pos, query_pos, radius, max_num_neighbors, include_self);
}
std::vector<at::Tensor> query_knn_cuda(
int batch_size,
at::Tensor batch_slices,
at::Tensor query_batch_slices,
at::Tensor pos,
at::Tensor query_pos,
int num_neighbors,
bool include_self);
std::vector<at::Tensor> query_knn(
int batch_size,
at::Tensor batch_slices,
at::Tensor query_batch_slices,
at::Tensor pos,
at::Tensor query_pos,
int num_neighbors,
bool include_self) {
CHECK_INPUT(batch_slices);
CHECK_INPUT(query_batch_slices);
CHECK_INPUT(pos);
CHECK_INPUT(query_pos);
return query_knn_cuda(batch_size, batch_slices, query_batch_slices, pos, query_pos, num_neighbors, include_self);
}
at::Tensor farthest_point_sampling_cuda(
int batch_size,
at::Tensor batch_slices,
at::Tensor pos,
int num_sample,
at::Tensor start_points);
at::Tensor farthest_point_sampling(
int batch_size,
at::Tensor batch_slices,
at::Tensor pos,
int num_sample,
at::Tensor start_points) {
CHECK_INPUT(batch_slices);
CHECK_INPUT(pos);
CHECK_INPUT(start_points);
return farthest_point_sampling_cuda(batch_size, batch_slices, pos, num_sample, start_points);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("query_radius", &query_radius, "Query Radius (CUDA)");
m.def("query_knn", &query_knn, "Query K-Nearest neighbor (CUDA)");
m.def("farthest_point_sampling", &farthest_point_sampling, "Farthest Point Sampling (CUDA)");
}
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdint.h>
#include <vector>
#define THREADS 1024
// Original by Qi. et al (https://github.com/charlesq34/pointnet2)
template <typename scalar_t>
__global__ void query_radius_cuda_kernel(
const int64_t* __restrict__ batch_slices,
const int64_t* __restrict__ query_batch_slices,
const scalar_t* __restrict__ pos,
const scalar_t* __restrict__ query_pos,
const scalar_t radius,
const int64_t max_num_neighbors,
const bool include_self,
int64_t* idx_output,
int64_t* cnt_output)
{
const int64_t batch_index = blockIdx.x;
const int64_t index = threadIdx.x;
const int64_t stride = blockDim.x;
const int64_t batch_start = batch_slices[2*batch_index];
const int64_t query_batch_start = query_batch_slices[2*batch_index];
const int64_t batch_end = batch_slices[2*batch_index+1];
const int64_t query_batch_end = query_batch_slices[2*batch_index+1];
const int64_t batch_size = batch_end - batch_start + 1;
const int64_t query_batch_size = query_batch_end - query_batch_start + 1;
pos += batch_start * 3;
query_pos += query_batch_start * 3;
idx_output += query_batch_start * max_num_neighbors;
cnt_output += query_batch_start;
for (int64_t j = index; j < query_batch_size; j+=stride){
int64_t cnt = 0;
scalar_t x2=query_pos[j*3+0];
scalar_t y2=query_pos[j*3+1];
scalar_t z2=query_pos[j*3+2];
// dummy outputs initialisation with value -1
if (cnt==0) {
for (int64_t l = 0;l < max_num_neighbors; l++)
idx_output[j*max_num_neighbors+l] = -1;
}
for (int64_t k = 0; k < batch_size; k++) {
if (cnt == max_num_neighbors)
break;
scalar_t x1=pos[k*3+0];
scalar_t y1=pos[k*3+1];
scalar_t z1=pos[k*3+2];
scalar_t d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
if (d <= radius && (d > 0 || include_self)) {
idx_output[j * max_num_neighbors + cnt] = batch_start + k;
cnt+=1;
}
}
cnt_output[j] = cnt;
}
}
template <typename scalar_t>
__global__ void query_knn_cuda_kernel(
const int64_t* __restrict__ batch_slices,
const int64_t* __restrict__ query_batch_slices,
const scalar_t* __restrict__ pos,
const scalar_t* __restrict__ query_pos,
const int64_t num_neighbors,
const bool include_self,
scalar_t* tmp_dists,
int64_t* idx_output){
const int64_t batch_index = blockIdx.x;
const int64_t index = threadIdx.x;
const int64_t stride = blockDim.x;
const int64_t batch_start = batch_slices[2*batch_index];
const int64_t query_batch_start = query_batch_slices[2*batch_index];
const int64_t batch_end = batch_slices[2*batch_index+1];
const int64_t query_batch_end = query_batch_slices[2*batch_index+1];
const int64_t batch_size = batch_end - batch_start + 1;
const int64_t query_batch_size = query_batch_end - query_batch_start + 1;
pos += batch_start * 3;
query_pos += query_batch_start * 3;
idx_output += query_batch_start * num_neighbors;
tmp_dists += query_batch_start * num_neighbors;
for (int64_t j = index; j < query_batch_size; j += stride){
scalar_t x2=query_pos[j*3+0];
scalar_t y2=query_pos[j*3+1];
scalar_t z2=query_pos[j*3+2];
// reset to dummy values
for (int64_t l = 0; l < num_neighbors; l++){
idx_output[j * num_neighbors + l] = -1;
tmp_dists[j * num_neighbors + l] = 2147483647;
}
for (int64_t k = 0; k < batch_size; k++) {
scalar_t x1=pos[k*3+0];
scalar_t y1=pos[k*3+1];
scalar_t z1=pos[k*3+2];
scalar_t d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
if (d > 0 || include_self){
for (int64_t i = 0; i < num_neighbors; i++){
if (tmp_dists[j * num_neighbors + i] > d){
for (int64_t i2 = num_neighbors-1; i2 > i; i2--){
tmp_dists[j * num_neighbors + i2] = tmp_dists[j * num_neighbors + i2 - 1];
idx_output[j * num_neighbors + i2] = idx_output[j * num_neighbors + i2 - 1];
}
tmp_dists[j * num_neighbors + i] = d;
idx_output[j * num_neighbors + i] = batch_start + k;
break;
}
}
}
}
}
}
template <typename scalar_t>
__global__ void farthest_point_sampling_kernel(
const int64_t* __restrict__ batch_slices,
const scalar_t* __restrict__ pos,
const int64_t num_sample,
const int64_t* __restrict__ start_points,
scalar_t* tmp_dists,
int64_t* idx_output){
const int64_t batch_index = blockIdx.x;
const int64_t index = threadIdx.x;
const int64_t stride = blockDim.x;
const int64_t batch_start = batch_slices[2*batch_index];
const int64_t batch_end = batch_slices[2*batch_index+1];
const int64_t batch_size = batch_end - batch_start + 1;
__shared__ scalar_t dists[THREADS];
__shared__ int64_t dists_i[THREADS];
pos += batch_start * 3;
idx_output += num_sample * batch_index;
tmp_dists += batch_start;
int64_t old = start_points[batch_index];
// explicitly handle the case where less than num_sample points are available to sample from
if (index == 0){
idx_output[0] = batch_start + old;
if (batch_size < num_sample){
for (int64_t i = 0; i < batch_size; i++){
idx_output[i] = batch_start + i;
}
for (int64_t i = batch_size; i < num_sample; i++){
idx_output[i] = -1;
}
}
}
if (batch_size < num_sample){
return;
}
// initialise temporary distances as big as possible
for (int64_t j = index; j < batch_size; j+=stride){
tmp_dists[j] = 2147483647;
}
__syncthreads();
for (int64_t i = 1; i < num_sample; i++){
int64_t besti = -1;
scalar_t best = -1;
// compute distance from last point to all others and update using the minimum of already computed distances
for (int64_t j = index; j < batch_size; j+= stride){
scalar_t td = tmp_dists[j];
scalar_t x1 = pos[old * 3 + 0];
scalar_t y1 = pos[old * 3 + 1];
scalar_t z1 = pos[old * 3 + 2];
scalar_t x2 = pos[j * 3 + 0];
scalar_t y2 = pos[j * 3 + 1];
scalar_t z2 = pos[j * 3 + 2];
scalar_t d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
scalar_t d2 = min(d, tmp_dists[j]);
if (td != d2){
tmp_dists[j] = d2;
}
if (tmp_dists[j] > best){
best = tmp_dists[j];
besti = j;
}
}
// sort best indices
dists[index] = best;
dists_i[index] = besti;
__syncthreads();
// get the maximum distances (by merging)
for (int64_t u = 0; (1<<u) < blockDim.x ; u++){
__syncthreads();
if (index < (blockDim.x >> (u+1))){
int64_t i1 = (index*2)<<u;
int64_t i2 = (index*2+1)<<u;
if (dists[i1] < dists[i2]){
dists[i1] = dists[i2];
dists_i[i1] = dists_i[i2];
}
}
}
__syncthreads();
if (dists[0] == 0){
break;
}
// thread 0 collects in output
old = dists_i[0];
if (index == 0){
idx_output[i] = batch_start + old;
}
}
}
std::vector<at::Tensor> query_radius_cuda(
int batch_size,
at::Tensor batch_slices,
at::Tensor query_batch_slices,
at::Tensor pos,
at::Tensor query_pos,
double radius,
int max_num_neighbors,
bool include_self) {
const auto num_points = query_pos.size(0);
auto idx_output = at::empty(pos.type().toScalarType(at::kLong), {num_points, max_num_neighbors});
auto cnt_output = at::empty(pos.type().toScalarType(at::kLong), {num_points});
AT_DISPATCH_FLOATING_TYPES(pos.type(), "query_radius_cuda_kernel", [&] {
query_radius_cuda_kernel<scalar_t><<<batch_size, THREADS>>>(
batch_slices.data<int64_t>(),
query_batch_slices.data<int64_t>(),
pos.data<scalar_t>(),
query_pos.data<scalar_t>(),
(scalar_t) radius*radius,
max_num_neighbors,
include_self,
idx_output.data<int64_t>(),
cnt_output.data<int64_t>());
});
return {idx_output, cnt_output};
}
std::vector<at::Tensor> query_knn_cuda(
int batch_size,
at::Tensor batch_slices,
at::Tensor query_batch_slices,
at::Tensor pos,
at::Tensor query_pos,
int num_neighbors,
bool include_self) {
const auto num_points = query_pos.size(0);
auto idx_output = at::empty(pos.type().toScalarType(at::kLong), {num_points, num_neighbors});
auto dists = at::empty(pos.type(), {num_points, num_neighbors});
AT_DISPATCH_FLOATING_TYPES(pos.type(), "query_knn_cuda_kernel", [&] {
query_knn_cuda_kernel<scalar_t><<<batch_size, THREADS>>>(
batch_slices.data<int64_t>(),
query_batch_slices.data<int64_t>(),
pos.data<scalar_t>(),
query_pos.data<scalar_t>(),
num_neighbors,
include_self,
dists.data<scalar_t>(),
idx_output.data<int64_t>());
});
return {idx_output, dists};
}
at::Tensor farthest_point_sampling_cuda(
int batch_size,
at::Tensor batch_slices,
at::Tensor pos,
int num_sample,
at::Tensor start_points) {
auto idx_output = at::empty(pos.type().toScalarType(at::kLong), {batch_size * num_sample});
auto tmp_dists = at::empty(pos.type(), {pos.size(0)});
AT_DISPATCH_FLOATING_TYPES(pos.type(), "farthest_point_sampling_kernel", [&] {
farthest_point_sampling_kernel<scalar_t><<<batch_size, THREADS>>>(
batch_slices.data<int64_t>(),
pos.data<scalar_t>(),
num_sample,
start_points.data<int64_t>(),
tmp_dists.data<scalar_t>(),
idx_output.data<int64_t>());
});
return idx_output;
}
import pytest
import torch
import numpy as np
from torch_geometric.data import Batch
from numpy.testing import assert_almost_equal
from capsules.utils.sample import sample_farthest, batch_slices, radius_query_edges
from .utils import tensor, grad_dtypes, devices
@pytest.mark.parametrize('device', devices)
def test_batch_slices(device):
# test sample case for correctness
batch = tensor([0] * 100 + [1] * 50 + [2] * 42, dtype=torch.long, device=device)
slices, sizes = batch_slices(batch, sizes=True)
slices, sizes = slices.cpu().tolist(), sizes.cpu().tolist()
assert slices == [0, 99, 100, 149, 150, 191]
assert sizes == [100, 50, 42]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype', grad_dtypes)
def test_fps(dtype):
# test simple case for correctness
batch = [0] * 10
points = [[-1, -1, 0], [-1, 1, 0], [1, 1, 0], [1, -1, 0]]
random_points = np.random.uniform(-1, 1, size=(6, 3))
random_points[:, 2] = 0
random_points = random_points.tolist()
batch = tensor(batch, dtype=torch.long, device='cuda')
pos = tensor(points + random_points, dtype=dtype, device='cuda')
idx = sample_farthest(batch, pos, num_sampled=4, index=True)
# needs update since isin is missing (sort indices, then compare?)
# assert isin(idx, tensor([0, 1, 2, 3], dtype=torch.long, device='cuda'), False).all().cpu().item() == 1
# test variable number of points for each element in a batch
batch = [0] * 100 + [1] * 50
points1 = np.random.uniform(-1, 1, size=(100, 3)).tolist()
points2 = np.random.uniform(-1, 1, size=(50, 3)).tolist()
batch = tensor(batch, dtype=torch.long, device='cuda')
pos = tensor(points1 + points2, dtype=dtype, device='cuda')
mask = sample_farthest(batch, pos, num_sampled=75, index=False)
assert mask[batch == 0].sum().item() == 75
assert mask[batch == 1].sum().item() == 50
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype', grad_dtypes)
def test_radius_edges(dtype):
batch = [0] * 100 + [1] * 50 + [2] * 42
points = np.random.uniform(-1, 1, size=(192, 3))
query_batch = [0] * 10 + [1] * 15 + [2] * 20
query_points = np.random.uniform(-1, 1, size=(45, 3))
radius = 0.5
batch = tensor(batch, dtype=torch.long, device='cuda')
query_batch = tensor(query_batch, dtype=torch.long, device='cuda')
pos = tensor(points, dtype=dtype, device='cuda')
query_pos = tensor(query_points, dtype=dtype, device='cuda')
edge_index = radius_query_edges(batch, pos, query_batch, query_pos, radius=radius, max_num_neighbors=128)
row, col = edge_index
dist = torch.norm(pos[col] - query_pos[row], p=2, dim=1)
assert (dist <= radius).all().item()
import torch
from torch_scatter import scatter_add, scatter_max
from torch_geometric.utils import to_undirected
from torch_geometric.data import Batch
from torch_sparse import coalesce
from sample_cuda import farthest_point_sampling, query_radius, query_knn
def batch_slices(batch, sizes=False, include_ends=True):
"""
Calculates size, start and end indices for each element in a batch.
"""
size = scatter_add(torch.ones_like(batch), batch)
cumsum = torch.cumsum(size, dim=0)
starts = cumsum - size
ends = cumsum - 1
slices = starts
if include_ends:
slices = torch.stack([starts, ends], dim=1).view(-1)
if sizes:
return slices, size
return slices
def sample_farthest(batch, pos, num_sampled, random_start=False, index=False):
"""
Samples a specified number of points for each element in a batch using farthest iterative point sampling and returns
a mask (or indices) for the sampled points.
If there are less than num_sampled points in a point cloud all points are returned.
"""
if not pos.is_cuda or not batch.is_cuda:
raise NotImplementedError
assert pos.is_contiguous() and batch.is_contiguous()
slices, sizes = batch_slices(batch, sizes=True)
batch_size = batch.max().item() + 1
if random_start:
random = torch.rand(batch_size, device=slices.device)
start_points = (sizes.float() * random).long()
else:
start_points = torch.zeros_like(sizes)
idx = farthest_point_sampling(batch_size, slices, pos, num_sampled,
start_points)
# Remove invalid indices
idx = idx[idx != -1]
if index:
return idx
mask = torch.zeros(pos.size(0), dtype=torch.uint8, device=pos.device)
mask[idx] = 1
return mask
def radius_query_edges(batch,
pos,
query_batch,
query_pos,
radius,
max_num_neighbors=128,
include_self=True,
undirected=False):
if not pos.is_cuda:
raise NotImplementedError
assert pos.is_cuda and batch.is_cuda and query_pos.is_cuda and query_batch.is_cuda
assert pos.is_contiguous() and batch.is_contiguous(
) and query_pos.is_contiguous() and query_batch.is_contiguous()
slices, sizes = batch_slices(batch, sizes=True)
batch_size = batch.max().item() + 1
query_slices = batch_slices(query_batch)
max_num_neighbors = min(max_num_neighbors, sizes.max().item())
idx, cnt = query_radius(batch_size, slices, query_slices, pos, query_pos,
radius, max_num_neighbors, include_self)
# Convert to edges
view = idx.view(-1)
row = torch.arange(query_pos.size(0), dtype=torch.long, device=pos.device)
row = row.view(-1, 1).repeat(1, max_num_neighbors).view(-1)
# Remove invalid indices
row = row[view != -1]
col = view[view != -1]
if col.size(0) == 0:
return col
edge_index = torch.stack([row, col], dim=0)
if undirected:
return to_undirected(edge_index, query_pos.size(0))
return edge_index
def radius_graph(batch,
pos,
radius,
max_num_neighbors=128,
include_self=False,
undirected=False):
return radius_query_edges(batch, pos, batch, pos, radius,
max_num_neighbors, include_self, undirected)
def knn_query_edges(batch,
pos,
query_batch,
query_pos,
num_neighbors,
include_self=True,
undirected=False):
if not pos.is_cuda:
raise NotImplementedError
assert pos.is_cuda and batch.is_cuda and query_pos.is_cuda and query_batch.is_cuda
assert pos.is_contiguous() and batch.is_contiguous(
) and query_pos.is_contiguous() and query_batch.is_contiguous()
slices, sizes = batch_slices(batch, sizes=True)
batch_size = batch.max().item() + 1
query_slices = batch_slices(query_batch)
assert (sizes < num_neighbors).sum().item() == 0
idx, dists = query_knn(batch_size, slices, query_slices, pos, query_pos,
num_neighbors, include_self)
# Convert to edges
view = idx.view(-1)
row = torch.arange(query_pos.size(0), dtype=torch.long, device=pos.device)
row = row.view(-1, 1).repeat(1, num_neighbors).view(-1)
# Remove invalid indices
row = row[view != -1]
col = view[view != -1]
edge_index = torch.stack([row, col], dim=0)
if undirected:
return to_undirected(edge_index, query_pos.size(0))
return edge_index
def knn_graph(batch, pos, num_neighbors, include_self=False, undirected=False):
return knn_query_edges(batch, pos, batch, pos, num_neighbors, include_self,
undirected)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment