Commit e3055164 authored by rusty1s's avatar rusty1s
Browse files

fix tests on GPU

parent 547759a6
......@@ -90,13 +90,15 @@ torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
CHECK_CUDA(ptr_x.value());
CHECK_INPUT(ptr_x.value().dim() == 1);
} else {
ptr_x = torch::tensor({0, x.size(0)}, x.options().dtype(torch::kLong));
ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
x.options().dtype(torch::kLong));
}
if (ptr_y.has_value()) {
CHECK_CUDA(ptr_y.value());
CHECK_INPUT(ptr_y.value().dim() == 1);
} else {
ptr_y = torch::tensor({0, y.size(0)}, y.options().dtype(torch::kLong));
ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
y.options().dtype(torch::kLong));
}
CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
......
......@@ -58,13 +58,15 @@ torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
CHECK_CUDA(ptr_x.value());
CHECK_INPUT(ptr_x.value().dim() == 1);
} else {
ptr_x = torch::tensor({0, x.size(0)}, x.options().dtype(torch::kLong));
ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
x.options().dtype(torch::kLong));
}
if (ptr_y.has_value()) {
CHECK_CUDA(ptr_y.value());
CHECK_INPUT(ptr_y.value().dim() == 1);
} else {
ptr_y = torch::tensor({0, y.size(0)}, y.options().dtype(torch::kLong));
ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
y.options().dtype(torch::kLong));
}
CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
......
......@@ -17,7 +17,7 @@ torch::Tensor knn(torch::Tensor x, torch::Tensor y,
int64_t num_workers) {
if (x.device().is_cuda()) {
#ifdef WITH_CUDA
return knn_cuda(x, y, ptr_x, ptr_x, k, cosine);
return knn_cuda(x, y, ptr_x, ptr_y, k, cosine);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
......
......@@ -8,6 +8,10 @@ from torch_cluster import knn, knn_graph
from .utils import grad_dtypes, devices, tensor
def to_set(edge_index):
return set([(i, j) for i, j in edge_index.t().tolist()])
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn(dtype, device):
x = tensor([
......@@ -28,18 +32,15 @@ def test_knn(dtype, device):
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
batch_y = tensor([0, 1], torch.long, device)
row, col = knn(x, y, 2)
assert row.tolist() == [0, 0, 1, 1]
assert col.tolist() == [2, 3, 0, 1]
edge_index = knn(x, y, 2)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])
row, col = knn(x, y, 2, batch_x, batch_y)
assert row.tolist() == [0, 0, 1, 1]
assert col.tolist() == [2, 3, 4, 5]
edge_index = knn(x, y, 2, batch_x, batch_y)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
if x.is_cuda:
row, col = knn(x, y, 2, batch_x, batch_y, cosine=True)
assert row.tolist() == [0, 0, 1, 1]
assert col.tolist() == [0, 1, 4, 5]
edge_index = knn(x, y, 2, batch_x, batch_y, cosine=True)
assert to_set(edge_index) == set([(0, 0), (0, 1), (1, 4), (1, 5)])
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
......@@ -51,25 +52,24 @@ def test_knn_graph(dtype, device):
[+1, -1],
], dtype, device)
row, col = knn_graph(x, k=2, flow='target_to_source')
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
edge_index = knn_graph(x, k=2, flow='target_to_source')
assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1),
(2, 3), (3, 0), (3, 2)])
row, col = knn_graph(x, k=2, flow='source_to_target')
assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
edge_index = knn_graph(x, k=2, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn_graph_large(dtype, device):
x = torch.randn(1000, 3)
row, col = knn_graph(x, k=5, flow='target_to_source', loop=True,
num_workers=6)
pred = set([(i, j) for i, j in zip(row.tolist(), col.tolist())])
edge_index = knn_graph(x, k=5, flow='target_to_source', loop=True,
num_workers=6)
tree = scipy.spatial.cKDTree(x.numpy())
_, col = tree.query(x.cpu(), k=5)
truth = set([(i, j) for i, ns in enumerate(col) for j in ns])
assert pred == truth
assert to_set(edge_index) == truth
......@@ -8,6 +8,10 @@ from torch_cluster import radius, radius_graph
from .utils import grad_dtypes, devices, tensor
def to_set(edge_index):
return set([(i, j) for i, j in edge_index.t().tolist()])
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius(dtype, device):
x = tensor([
......@@ -28,11 +32,13 @@ def test_radius(dtype, device):
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
batch_y = tensor([0, 1], torch.long, device)
out = radius(x, y, 2, max_num_neighbors=4)
assert out.tolist() == [[0, 0, 0, 0, 1, 1, 1, 1], [0, 1, 2, 3, 1, 2, 5, 6]]
edge_index = radius(x, y, 2, max_num_neighbors=4)
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
(1, 2), (1, 5), (1, 6)])
out = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
assert out.tolist() == [[0, 0, 0, 0, 1, 1], [0, 1, 2, 3, 5, 6]]
edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5),
(1, 6)])
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
......@@ -44,25 +50,24 @@ def test_radius_graph(dtype, device):
[+1, -1],
], dtype, device)
row, col = radius_graph(x, r=2, flow='target_to_source')
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
edge_index = radius_graph(x, r=2, flow='target_to_source')
assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1),
(2, 3), (3, 0), (3, 2)])
row, col = radius_graph(x, r=2, flow='source_to_target')
assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
edge_index = radius_graph(x, r=2, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius_graph_large(dtype, device):
x = torch.randn(1000, 3)
row, col = radius_graph(x, r=0.5, flow='target_to_source', loop=True,
max_num_neighbors=1000, num_workers=6)
pred = set([(i, j) for i, j in zip(row.tolist(), col.tolist())])
edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True,
max_num_neighbors=1000, num_workers=6)
tree = scipy.spatial.cKDTree(x.numpy())
col = tree.query_ball_point(x.cpu(), r=0.5)
truth = set([(i, j) for i, ns in enumerate(col) for j in ns])
assert pred == truth
assert to_set(edge_index) == truth
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment