Commit 2baa79fa authored by Jan Eric Lenssen's avatar Jan Eric Lenssen
Browse files

cpu radius version now max number of neighbors

parent dfe188ab
...@@ -28,7 +28,6 @@ def test_radius(dtype, device): ...@@ -28,7 +28,6 @@ def test_radius(dtype, device):
batch_y = tensor([0, 1], torch.long, device) batch_y = tensor([0, 1], torch.long, device)
out = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4) out = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
assert out.tolist() == [[0, 0, 0, 0, 1, 1], [0, 1, 2, 3, 5, 6]] assert out.tolist() == [[0, 0, 0, 0, 1, 1], [0, 1, 2, 3, 5, 6]]
...@@ -41,7 +40,7 @@ def test_radius_graph(dtype, device): ...@@ -41,7 +40,7 @@ def test_radius_graph(dtype, device):
[+1, -1], [+1, -1],
], dtype, device) ], dtype, device)
row, col = radius_graph(x, r=2) row, col = radius_graph(x, r=(2.0+1e-16))
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2] assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
...@@ -64,11 +64,12 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32): ...@@ -64,11 +64,12 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1) y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
tree = scipy.spatial.cKDTree(x) tree = scipy.spatial.cKDTree(x)
col = tree.query_ball_point(y, r) _, col = tree.query(y, k=max_num_neighbors, distance_upper_bound=r)
col = [torch.tensor(c) for c in col] col = [torch.tensor(c) for c in col]
row = [torch.full_like(c, i) for i, c in enumerate(col)] row = [torch.full_like(c, i) for i, c in enumerate(col)]
row, col = torch.cat(row, dim=0), torch.cat(col, dim=0) row, col = torch.cat(row, dim=0), torch.cat(col, dim=0)
row = row[col<tree.n]
col = col[col<tree.n]
return torch.stack([row, col], dim=0) return torch.stack([row, col], dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment