Commit 07c92be4 authored by rusty1s's avatar rusty1s
Browse files

test graph gen

parent 9725bf76
......@@ -2,13 +2,13 @@ from itertools import product
import pytest
import torch
from torch_cluster import knn
from torch_cluster import knn, knn_graph
from .utils import grad_dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius(dtype, device):
def test_knn(dtype, device):
x = tensor([
[-1, -1],
[-1, +1],
......@@ -27,9 +27,24 @@ def test_radius(dtype, device):
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
batch_y = tensor([0, 1], torch.long, device)
out = knn(x, y, 2, batch_x, batch_y)
assert out[0].tolist() == [0, 0, 1, 1]
col = out[1][:2].tolist()
assert col == [2, 3] or col == [3, 2]
col = out[1][2:].tolist()
assert col == [4, 5] or col == [5, 4]
row, col = knn(x, y, 2, batch_x, batch_y)
col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [0, 0, 1, 1]
assert col.tolist() == [2, 3, 4, 5]
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn_graph(dtype, device):
x = tensor([
[-1, -1],
[-1, +1],
[+1, +1],
[+1, -1],
], dtype, device)
row, col = knn_graph(x, k=2)
col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
......@@ -2,7 +2,7 @@ from itertools import product
import pytest
import torch
from torch_cluster import radius
from torch_cluster import radius, radius_graph
from .utils import grad_dtypes, devices, tensor
......@@ -28,4 +28,20 @@ def test_radius(dtype, device):
batch_y = tensor([0, 1], torch.long, device)
out = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
assert out.tolist() == [[0, 0, 0, 0, 1, 1], [0, 1, 2, 3, 5, 6]]
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius_graph(dtype, device):
x = tensor([
[-1, -1],
[-1, +1],
[+1, +1],
[+1, -1],
], dtype, device)
row, col = radius_graph(x, r=2)
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment