Commit 2ce0f235 authored by rusty1s's avatar rusty1s
Browse files

fix radius test

parent e0e5a84c
......@@ -50,11 +50,11 @@ def test_radius_graph(dtype, device):
[+1, -1],
], dtype, device)
edge_index = radius_graph(x, r=2, flow='target_to_source')
edge_index = radius_graph(x, r=2.5, flow='target_to_source')
assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1),
(2, 3), (3, 0), (3, 2)])
edge_index = radius_graph(x, r=2, flow='source_to_target')
edge_index = radius_graph(x, r=2.5, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])
......@@ -67,7 +67,7 @@ def test_radius_graph_large(dtype, device):
max_num_neighbors=2000, num_workers=6)
tree = scipy.spatial.cKDTree(x.numpy())
col = tree.query_ball_point(x.cpu(), r=0.5 + 0.00001)
col = tree.query_ball_point(x.cpu(), r=0.5)
truth = set([(i, j) for i, ns in enumerate(col) for j in ns])
assert to_set(edge_index) == truth
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment