Commit 0044fc97 authored by rusty1s's avatar rusty1s
Browse files

radius graph

parent ea7df4db
#include <torch/torch.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous");
at::Tensor radius_cuda(at::Tensor x, at::Tensor y, float radius,
at::Tensor batch_x, at::Tensor batch_y,
size_t max_num_neighbors);
at::Tensor radius(at::Tensor x, at::Tensor y, float radius, at::Tensor batch_x,
at::Tensor batch_y, size_t max_num_neighbors) {
CHECK_CUDA(x);
IS_CONTIGUOUS(x);
CHECK_CUDA(y);
IS_CONTIGUOUS(y);
CHECK_CUDA(batch_x);
CHECK_CUDA(batch_y);
return radius_cuda(x, y, radius, batch_x, batch_y, max_num_neighbors);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("radius", &radius, "Radius (CUDA)");
}
#include <ATen/ATen.h>
#include "utils.cuh"
#define THREADS 1024
template <typename scalar_t>
__global__ void
radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
const int64_t *__restrict__ batch_x,
const int64_t *__restrict__ batch_y, int64_t *__restrict__ row,
int64_t *__restrict__ col, scalar_t radius,
size_t max_num_neighbors, size_t dim) {
const ptrdiff_t batch_idx = blockIdx.x;
const ptrdiff_t idx = threadIdx.x;
const ptrdiff_t start_idx_x = batch_x[batch_idx];
const ptrdiff_t end_idx_x = batch_x[batch_idx + 1];
const ptrdiff_t start_idx_y = batch_y[batch_idx];
const ptrdiff_t end_idx_y = batch_y[batch_idx + 1];
for (ptrdiff_t n_y = start_idx_y + idx; n_y < end_idx_y; n_y += THREADS) {
size_t count = 0;
for (ptrdiff_t n_x = start_idx_x; n_x < end_idx_x; n_x++) {
scalar_t dist = 0;
for (ptrdiff_t d = 0; d < dim; d++) {
dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
}
dist = sqrt(dist);
if (dist <= radius) {
row[n_y * max_num_neighbors + count] = n_y;
col[n_y * max_num_neighbors + count] = n_x;
count++;
}
if (count >= max_num_neighbors) {
continue;
}
}
}
}
at::Tensor radius_cuda(at::Tensor x, at::Tensor y, float radius,
at::Tensor batch_x, at::Tensor batch_y,
size_t max_num_neighbors) {
auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
cudaMemcpyDeviceToHost);
auto batch_size = batch_sizes[0] + 1;
batch_x = degree(batch_x, batch_size);
batch_x = at::cat({at::zeros(1, batch_x.options()), batch_x.cumsum(0)}, 0);
batch_y = degree(batch_y, batch_size);
batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);
auto row = at::full(y.size(0) * max_num_neighbors, -1, batch_y.options());
auto col = at::full(y.size(0) * max_num_neighbors, -1, batch_y.options());
AT_DISPATCH_FLOATING_TYPES(x.type(), "radius_kernel", [&] {
radius_kernel<scalar_t><<<batch_size, THREADS>>>(
x.data<scalar_t>(), y.data<scalar_t>(), batch_x.data<int64_t>(),
batch_y.data<int64_t>(), row.data<int64_t>(), col.data<int64_t>(),
radius, max_num_neighbors, x.size(1));
});
auto mask = row != -1;
return at::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
......@@ -16,6 +16,8 @@ if torch.cuda.is_available():
CUDAExtension('fps_cuda', ['cuda/fps.cpp', 'cuda/fps_kernel.cu']),
CUDAExtension('nearest_cuda',
['cuda/nearest.cpp', 'cuda/nearest_kernel.cu']),
CUDAExtension('radius_cuda',
['cuda/radius.cpp', 'cuda/radius_kernel.cu']),
]
__version__ = '1.2.0'
......
from itertools import product
import pytest
import torch
from torch_cluster import radius
from .utils import tensor
devices = [torch.device('cuda')]
grad_dtypes = [torch.float]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius(dtype, device):
x = tensor([
[-1, -1],
[-1, +1],
[+1, +1],
[+1, -1],
[-1, -1],
[-1, +1],
[+1, +1],
[+1, -1],
], dtype, device)
y = tensor([
[0, 0],
[0, 1],
], dtype, device)
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
batch_y = tensor([0, 1], torch.long, device)
out = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
print()
print([[0, 0, 0, 0, 1, 1], [0, 1, 2, 3, 5, 6]])
print(out)
......@@ -2,6 +2,7 @@ from .graclus import graclus_cluster
from .grid import grid_cluster
from .fps import fps
from .nearest import nearest
from .radius import radius, radius_graph
__version__ = '1.2.0'
......@@ -10,5 +11,7 @@ __all__ = [
'grid_cluster',
'fps',
'nearest',
'radius',
'radius_graph',
'__version__',
]
import torch
if torch.cuda.is_available():
import radius_cuda
def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
if batch_x is None:
batch_x = x.new_zeros(x.size(0), dtype=torch.long)
if batch_y is None:
batch_y = y.new_zeros(y.size(0), dtype=torch.long)
x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y
assert x.is_cuda
assert x.dim() == 2 and batch_x.dim() == 1
assert y.dim() == 2 and batch_y.dim() == 1
assert x.size(1) == y.size(1)
assert x.size(0) == batch_x.size(0)
assert y.size(0) == batch_y.size(0)
op = radius_cuda.radius if x.is_cuda else None
assign_index = op(x, y, r, batch_x, batch_y, max_num_neighbors)
return assign_index
def radius_graph(x, r, batch=None, max_num_neighbors=32):
edge_index = radius(x, x, r, batch, batch, max_num_neighbors + 1)
row, col = edge_index
mask = row != col
row, col = row[mask], col[mask]
return torch.stack([row, col], dim=0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment