Commit c7aeb8d2 authored by rusty1s's avatar rusty1s
Browse files

randomly sample entries if entries exceed max_num_neighbors

parent 73412857
......@@ -6,6 +6,8 @@ from torch_cluster import radius, radius_graph
from .utils import grad_dtypes, devices, tensor
grad_dtypes = [torch.float]
def coalesce(index):
N = index.max().item() + 1
......
......@@ -5,6 +5,12 @@ if torch.cuda.is_available():
import torch_cluster.radius_cuda
def sample(col, count):
if col.size(0) > count:
col = col[torch.randperm(col.size(0))][:count]
return col
def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
r"""Finds for each element in :obj:`y` all points in :obj:`x` within
distance :obj:`r`.
......@@ -64,20 +70,15 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
tree = scipy.spatial.cKDTree(x.detach().numpy())
_, col = tree.query(
y.detach().numpy(), k=max_num_neighbors, distance_upper_bound=r + 1e-8)
col = [torch.from_numpy(c).to(torch.long) for c in col]
col = tree.query_ball_point(y.detach().numpy(), r)
col = [sample(torch.tensor(c), max_num_neighbors) for c in col]
row = [torch.full_like(c, i) for i, c in enumerate(col)]
row, col = torch.cat(row, dim=0), torch.cat(col, dim=0)
mask = col < int(tree.n)
return torch.stack([row[mask], col[mask]], dim=0)
def radius_graph(x,
r,
batch=None,
loop=False,
max_num_neighbors=32,
def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32,
flow='source_to_target'):
r"""Computes graph edges to all points within a given distance.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment