Commit e3055164 authored by rusty1s's avatar rusty1s
Browse files

fix tests on GPU

parent 547759a6
...@@ -90,13 +90,15 @@ torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, ...@@ -90,13 +90,15 @@ torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
CHECK_CUDA(ptr_x.value()); CHECK_CUDA(ptr_x.value());
CHECK_INPUT(ptr_x.value().dim() == 1); CHECK_INPUT(ptr_x.value().dim() == 1);
} else { } else {
ptr_x = torch::tensor({0, x.size(0)}, x.options().dtype(torch::kLong)); ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
x.options().dtype(torch::kLong));
} }
if (ptr_y.has_value()) { if (ptr_y.has_value()) {
CHECK_CUDA(ptr_y.value()); CHECK_CUDA(ptr_y.value());
CHECK_INPUT(ptr_y.value().dim() == 1); CHECK_INPUT(ptr_y.value().dim() == 1);
} else { } else {
ptr_y = torch::tensor({0, y.size(0)}, y.options().dtype(torch::kLong)); ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
y.options().dtype(torch::kLong));
} }
CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel()); CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
......
...@@ -58,13 +58,15 @@ torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y, ...@@ -58,13 +58,15 @@ torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
CHECK_CUDA(ptr_x.value()); CHECK_CUDA(ptr_x.value());
CHECK_INPUT(ptr_x.value().dim() == 1); CHECK_INPUT(ptr_x.value().dim() == 1);
} else { } else {
ptr_x = torch::tensor({0, x.size(0)}, x.options().dtype(torch::kLong)); ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
x.options().dtype(torch::kLong));
} }
if (ptr_y.has_value()) { if (ptr_y.has_value()) {
CHECK_CUDA(ptr_y.value()); CHECK_CUDA(ptr_y.value());
CHECK_INPUT(ptr_y.value().dim() == 1); CHECK_INPUT(ptr_y.value().dim() == 1);
} else { } else {
ptr_y = torch::tensor({0, y.size(0)}, y.options().dtype(torch::kLong)); ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
y.options().dtype(torch::kLong));
} }
CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel()); CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
......
...@@ -17,7 +17,7 @@ torch::Tensor knn(torch::Tensor x, torch::Tensor y, ...@@ -17,7 +17,7 @@ torch::Tensor knn(torch::Tensor x, torch::Tensor y,
int64_t num_workers) { int64_t num_workers) {
if (x.device().is_cuda()) { if (x.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
return knn_cuda(x, y, ptr_x, ptr_x, k, cosine); return knn_cuda(x, y, ptr_x, ptr_y, k, cosine);
#else #else
AT_ERROR("Not compiled with CUDA support"); AT_ERROR("Not compiled with CUDA support");
#endif #endif
......
...@@ -8,6 +8,10 @@ from torch_cluster import knn, knn_graph ...@@ -8,6 +8,10 @@ from torch_cluster import knn, knn_graph
from .utils import grad_dtypes, devices, tensor from .utils import grad_dtypes, devices, tensor
def to_set(edge_index):
    """Convert a ``(2, num_edges)`` edge-index tensor into a set of pairs.

    Args:
        edge_index: Long tensor of shape ``(2, num_edges)`` holding
            (row, col) indices, one edge per column.

    Returns:
        A ``set`` of ``(row, col)`` tuples, so edge lists can be compared
        irrespective of edge ordering (CPU and GPU kernels may emit edges
        in different orders).
    """
    # Set comprehension (idiomatic per flake8-comprehensions C403)
    # instead of set([...]); .t() makes each row of the iterated tensor
    # one (row, col) edge.
    return {(row, col) for row, col in edge_index.t().tolist()}
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn(dtype, device): def test_knn(dtype, device):
x = tensor([ x = tensor([
...@@ -28,18 +32,15 @@ def test_knn(dtype, device): ...@@ -28,18 +32,15 @@ def test_knn(dtype, device):
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device) batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
batch_y = tensor([0, 1], torch.long, device) batch_y = tensor([0, 1], torch.long, device)
row, col = knn(x, y, 2) edge_index = knn(x, y, 2)
assert row.tolist() == [0, 0, 1, 1] assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])
assert col.tolist() == [2, 3, 0, 1]
row, col = knn(x, y, 2, batch_x, batch_y) edge_index = knn(x, y, 2, batch_x, batch_y)
assert row.tolist() == [0, 0, 1, 1] assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
assert col.tolist() == [2, 3, 4, 5]
if x.is_cuda: if x.is_cuda:
row, col = knn(x, y, 2, batch_x, batch_y, cosine=True) edge_index = knn(x, y, 2, batch_x, batch_y, cosine=True)
assert row.tolist() == [0, 0, 1, 1] assert to_set(edge_index) == set([(0, 0), (0, 1), (1, 4), (1, 5)])
assert col.tolist() == [0, 1, 4, 5]
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
...@@ -51,25 +52,24 @@ def test_knn_graph(dtype, device): ...@@ -51,25 +52,24 @@ def test_knn_graph(dtype, device):
[+1, -1], [+1, -1],
], dtype, device) ], dtype, device)
row, col = knn_graph(x, k=2, flow='target_to_source') edge_index = knn_graph(x, k=2, flow='target_to_source')
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1),
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2] (2, 3), (3, 0), (3, 2)])
row, col = knn_graph(x, k=2, flow='source_to_target') edge_index = knn_graph(x, k=2, flow='source_to_target')
assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2] assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] (3, 2), (0, 3), (2, 3)])
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn_graph_large(dtype, device): def test_knn_graph_large(dtype, device):
x = torch.randn(1000, 3) x = torch.randn(1000, 3)
row, col = knn_graph(x, k=5, flow='target_to_source', loop=True, edge_index = knn_graph(x, k=5, flow='target_to_source', loop=True,
num_workers=6) num_workers=6)
pred = set([(i, j) for i, j in zip(row.tolist(), col.tolist())])
tree = scipy.spatial.cKDTree(x.numpy()) tree = scipy.spatial.cKDTree(x.numpy())
_, col = tree.query(x.cpu(), k=5) _, col = tree.query(x.cpu(), k=5)
truth = set([(i, j) for i, ns in enumerate(col) for j in ns]) truth = set([(i, j) for i, ns in enumerate(col) for j in ns])
assert pred == truth assert to_set(edge_index) == truth
...@@ -8,6 +8,10 @@ from torch_cluster import radius, radius_graph ...@@ -8,6 +8,10 @@ from torch_cluster import radius, radius_graph
from .utils import grad_dtypes, devices, tensor from .utils import grad_dtypes, devices, tensor
def to_set(edge_index):
    """Convert a ``(2, num_edges)`` edge-index tensor into a set of pairs.

    Args:
        edge_index: Long tensor of shape ``(2, num_edges)`` holding
            (row, col) indices, one edge per column.

    Returns:
        A ``set`` of ``(row, col)`` tuples, so edge lists can be compared
        irrespective of edge ordering (CPU and GPU kernels may emit edges
        in different orders).
    """
    # Set comprehension (idiomatic per flake8-comprehensions C403)
    # instead of set([...]); .t() makes each row of the iterated tensor
    # one (row, col) edge.
    return {(row, col) for row, col in edge_index.t().tolist()}
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius(dtype, device): def test_radius(dtype, device):
x = tensor([ x = tensor([
...@@ -28,11 +32,13 @@ def test_radius(dtype, device): ...@@ -28,11 +32,13 @@ def test_radius(dtype, device):
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device) batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
batch_y = tensor([0, 1], torch.long, device) batch_y = tensor([0, 1], torch.long, device)
out = radius(x, y, 2, max_num_neighbors=4) edge_index = radius(x, y, 2, max_num_neighbors=4)
assert out.tolist() == [[0, 0, 0, 0, 1, 1, 1, 1], [0, 1, 2, 3, 1, 2, 5, 6]] assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
(1, 2), (1, 5), (1, 6)])
out = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4) edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
assert out.tolist() == [[0, 0, 0, 0, 1, 1], [0, 1, 2, 3, 5, 6]] assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5),
(1, 6)])
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
...@@ -44,25 +50,24 @@ def test_radius_graph(dtype, device): ...@@ -44,25 +50,24 @@ def test_radius_graph(dtype, device):
[+1, -1], [+1, -1],
], dtype, device) ], dtype, device)
row, col = radius_graph(x, r=2, flow='target_to_source') edge_index = radius_graph(x, r=2, flow='target_to_source')
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1),
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2] (2, 3), (3, 0), (3, 2)])
row, col = radius_graph(x, r=2, flow='source_to_target') edge_index = radius_graph(x, r=2, flow='source_to_target')
assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2] assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] (3, 2), (0, 3), (2, 3)])
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius_graph_large(dtype, device): def test_radius_graph_large(dtype, device):
x = torch.randn(1000, 3) x = torch.randn(1000, 3)
row, col = radius_graph(x, r=0.5, flow='target_to_source', loop=True, edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True,
max_num_neighbors=1000, num_workers=6) max_num_neighbors=1000, num_workers=6)
pred = set([(i, j) for i, j in zip(row.tolist(), col.tolist())])
tree = scipy.spatial.cKDTree(x.numpy()) tree = scipy.spatial.cKDTree(x.numpy())
col = tree.query_ball_point(x.cpu(), r=0.5) col = tree.query_ball_point(x.cpu(), r=0.5)
truth = set([(i, j) for i, ns in enumerate(col) for j in ns]) truth = set([(i, j) for i, ns in enumerate(col) for j in ns])
assert pred == truth assert to_set(edge_index) == truth
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment