Commit 52cf0d12 authored by rusty1s's avatar rusty1s
Browse files

test GPU

parent 86f2e4a0
...@@ -63,13 +63,13 @@ def test_knn_graph(dtype, device): ...@@ -63,13 +63,13 @@ def test_knn_graph(dtype, device):
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn_graph_large(dtype, device): def test_knn_graph_large(dtype, device):
x = torch.randn(1000, 3) x = torch.randn(1000, 3, dtype=dtype, device=device)
edge_index = knn_graph(x, k=5, flow='target_to_source', loop=True, edge_index = knn_graph(x, k=5, flow='target_to_source', loop=True,
num_workers=6) num_workers=6)
tree = scipy.spatial.cKDTree(x.numpy()) tree = scipy.spatial.cKDTree(x.cpu().numpy())
_, col = tree.query(x.cpu(), k=5) _, col = tree.query(x.cpu(), k=5)
truth = set([(i, j) for i, ns in enumerate(col) for j in ns]) truth = set([(i, j) for i, ns in enumerate(col) for j in ns])
assert to_set(edge_index) == truth assert to_set(edge_index.cpu()) == truth
...@@ -61,13 +61,13 @@ def test_radius_graph(dtype, device): ...@@ -61,13 +61,13 @@ def test_radius_graph(dtype, device):
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius_graph_large(dtype, device): def test_radius_graph_large(dtype, device):
x = torch.randn(1000, 3) x = torch.randn(1000, 3, dtype=dtype, device=device)
edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True, edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True,
max_num_neighbors=2000, num_workers=6) max_num_neighbors=2000, num_workers=6)
tree = scipy.spatial.cKDTree(x.numpy()) tree = scipy.spatial.cKDTree(x.cpu().numpy())
col = tree.query_ball_point(x.cpu(), r=0.5) col = tree.query_ball_point(x.cpu(), r=0.5)
truth = set([(i, j) for i, ns in enumerate(col) for j in ns]) truth = set([(i, j) for i, ns in enumerate(col) for j in ns])
assert to_set(edge_index) == truth assert to_set(edge_index.cpu()) == truth
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment