Commit 10049daf authored by rusty1s's avatar rusty1s
Browse files

fix knn/radius for batches with zero-point examples

parent ae639fd6
...@@ -67,6 +67,9 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y, ...@@ -67,6 +67,9 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
auto x_start = ptr_x_data[b], x_end = ptr_x_data[b + 1]; auto x_start = ptr_x_data[b], x_end = ptr_x_data[b + 1];
auto y_start = ptr_y_data[b], y_end = ptr_y_data[b + 1]; auto y_start = ptr_y_data[b], y_end = ptr_y_data[b + 1];
if (x_start == x_end || y_start == y_end)
continue;
vec_t pts(x_end - x_start); vec_t pts(x_end - x_start);
for (int64_t i = 0; i < x_end - x_start; i++) { for (int64_t i = 0; i < x_end - x_start; i++) {
pts[i].resize(x.size(1)); pts[i].resize(x.size(1));
......
...@@ -70,6 +70,9 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y, ...@@ -70,6 +70,9 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
auto x_start = ptr_x_data[b], x_end = ptr_x_data[b + 1]; auto x_start = ptr_x_data[b], x_end = ptr_x_data[b + 1];
auto y_start = ptr_y_data[b], y_end = ptr_y_data[b + 1]; auto y_start = ptr_y_data[b], y_end = ptr_y_data[b + 1];
if (x_start == x_end || y_start == y_end)
continue;
vec_t pts(x_end - x_start); vec_t pts(x_end - x_start);
for (int64_t i = 0; i < x_end - x_start; i++) { for (int64_t i = 0; i < x_end - x_start; i++) {
pts[i].resize(x.size(1)); pts[i].resize(x.size(1));
......
...@@ -42,6 +42,12 @@ def test_knn(dtype, device): ...@@ -42,6 +42,12 @@ def test_knn(dtype, device):
edge_index = knn(x, y, 2, batch_x, batch_y, cosine=True) edge_index = knn(x, y, 2, batch_x, batch_y, cosine=True)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
# Skipping a batch
batch_x = tensor([0, 0, 0, 0, 2, 2, 2, 2], torch.long, device)
batch_y = tensor([0, 2], torch.long, device)
edge_index = knn(x, y, 2, batch_x, batch_y)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn_graph(dtype, device): def test_knn_graph(dtype, device):
......
...@@ -40,6 +40,13 @@ def test_radius(dtype, device): ...@@ -40,6 +40,13 @@ def test_radius(dtype, device):
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5), assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5),
(1, 6)]) (1, 6)])
# Skipping a batch
batch_x = tensor([0, 0, 0, 0, 2, 2, 2, 2], torch.long, device)
batch_y = tensor([0, 2], torch.long, device)
edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5),
(1, 6)])
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius_graph(dtype, device): def test_radius_graph(dtype, device):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment