Commit 5549ef86 authored by rusty1s's avatar rusty1s
Browse files

no sort

parent 79f67548
......@@ -7,6 +7,13 @@ from torch_cluster import radius, radius_graph
from .utils import grad_dtypes, devices, tensor
def coalesce(index):
N = index.max().item() + 1
tensor = torch.sparse_coo_tensor(index, index.new_ones(index.size(1)),
torch.Size([N, N]))
return tensor.coalesce().indices()
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius(dtype, device):
x = tensor([
......@@ -28,7 +35,7 @@ def test_radius(dtype, device):
batch_y = tensor([0, 1], torch.long, device)
out = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
assert out.tolist() == [[0, 0, 0, 0, 1, 1], [0, 1, 2, 3, 5, 6]]
assert coalesce(out).tolist() == [[0, 0, 0, 0, 1, 1], [0, 1, 2, 3, 5, 6]]
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
......@@ -40,7 +47,6 @@ def test_radius_graph(dtype, device):
[+1, -1],
], dtype, device)
row, col = radius_graph(x, r=(2.0+1e-16))
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
out = radius_graph(x, r=2)
assert coalesce(out).tolist() == [[0, 0, 1, 1, 2, 2, 3, 3],
[1, 3, 0, 2, 1, 3, 0, 2]]
......@@ -65,7 +65,7 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
tree = scipy.spatial.cKDTree(x)
_, col = tree.query(y, k=max_num_neighbors, distance_upper_bound=r)
_, col = tree.query(y, k=max_num_neighbors, distance_upper_bound=r + 1e-8)
col = [torch.tensor(c) for c in col]
row = [torch.full_like(c, i) for i, c in enumerate(col)]
row, col = torch.cat(row, dim=0), torch.cat(col, dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment