You need to sign in or sign up before continuing.
Commit 9725bf76 authored by rusty1s's avatar rusty1s
Browse files

radius cpu version

parent 4e6cb0cf
......@@ -4,12 +4,9 @@ import pytest
import torch
from torch_cluster import radius
from .utils import tensor, grad_dtypes
from .utils import grad_dtypes, devices, tensor
devices = [torch.device('cuda')]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius(dtype, device):
x = tensor([
......
......@@ -46,8 +46,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
assert y.size(0) == batch_y.size(0)
if x.is_cuda:
assign_index = knn_cuda.knn(x, y, k, batch_x, batch_y)
return assign_index
return knn_cuda.knn(x, y, k, batch_x, batch_y)
# Rescale x and y.
min_xy = min(x.min().item(), y.min().item())
......
import torch
import scipy.spatial
if torch.cuda.is_available():
import radius_cuda
......@@ -40,17 +41,25 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
x = x.view(-1, 1) if x.dim() == 1 else x
y = y.view(-1, 1) if y.dim() == 1 else y
assert x.is_cuda
assert x.dim() == 2 and batch_x.dim() == 1
assert y.dim() == 2 and batch_y.dim() == 1
assert x.size(1) == y.size(1)
assert x.size(0) == batch_x.size(0)
assert y.size(0) == batch_y.size(0)
op = radius_cuda.radius if x.is_cuda else None
assign_index = op(x, y, r, batch_x, batch_y, max_num_neighbors)
if x.is_cuda:
return radius_cuda.radius(x, y, r, batch_x, batch_y, max_num_neighbors)
return assign_index
x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
tree = scipy.spatial.cKDTree(x)
col = tree.query_ball_point(y, r)
col = [torch.tensor(c) for c in col]
row = [torch.full_like(c, i) for i, c in enumerate(col)]
row, col = torch.cat(row, dim=0), torch.cat(col, dim=0)
return torch.stack([row, col], dim=0)
def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment