test_knn.py 2.76 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
from itertools import product

import pytest
import scipy.spatial
limm's avatar
limm committed
5
import torch
yangzhong's avatar
yangzhong committed
6
from torch_cluster import knn, knn_graph
limm's avatar
limm committed
7
from torch_cluster.testing import devices, grad_dtypes, tensor
yangzhong's avatar
yangzhong committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36


def to_set(edge_index):
    """Return the columns of a two-row ``edge_index`` as ``(row, col)`` tuples.

    Collecting the edges into a set makes the comparisons below independent
    of the order in which the kernels emit edges.
    """
    return {(i, j) for i, j in edge_index.t().tolist()}


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn(dtype, device):
    """Check ``knn`` against hand-computed neighbours.

    ``x`` lists the four corners of the square [-1, 1]^2 twice (rows 0-3 and
    rows 4-7) and ``y`` holds the two queries (1, 0) and (-1, 0). Each edge
    is compared as ``(y_index, x_index)``, with ``k=2`` neighbours per query.
    """
    x = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
    ], dtype, device)
    y = tensor([
        [1, 0],
        [-1, 0],
    ], dtype, device)

    batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
    batch_y = tensor([0, 1], torch.long, device)

    # Unbatched: all eight x rows are candidates. Rows 4-7 duplicate rows
    # 0-3, so this expectation relies on ties resolving to the lower index.
    unbatched = {(0, 2), (0, 3), (1, 0), (1, 1)}
    # Batched: query 1 may only match the x rows of its own batch (4-7).
    batched = {(0, 2), (0, 3), (1, 4), (1, 5)}

    edge_index = knn(x, y, 2)
    assert to_set(edge_index) == unbatched

    # The TorchScript-compiled function must agree with eager mode.
    jit = torch.jit.script(knn)
    edge_index = jit(x, y, 2)
    assert to_set(edge_index) == unbatched

    edge_index = knn(x, y, 2, batch_x, batch_y)
    assert to_set(edge_index) == batched

    # The cosine variant is only exercised for CUDA tensors in this test.
    if x.is_cuda:
        edge_index = knn(x, y, 2, batch_x, batch_y, cosine=True)
        assert to_set(edge_index) == batched

    # Skipping a batch: ids 0 and 2 (batch 1 empty) must give the same
    # neighbours as the contiguous ids 0 and 1 used above.
    batch_x = tensor([0, 0, 0, 0, 2, 2, 2, 2], torch.long, device)
    batch_y = tensor([0, 2], torch.long, device)
    edge_index = knn(x, y, 2, batch_x, batch_y)
    assert to_set(edge_index) == batched


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn_graph(dtype, device):
    """Check ``knn_graph`` on the four corners of a square for both flows.

    With ``k=2`` every corner is linked to its two adjacent corners. The
    ``'source_to_target'`` expectation is the ``'target_to_source'`` one with
    every edge reversed.

    NOTE(review): this adjacency is symmetric (each (i, j) has a matching
    (j, i)), so both flows yield the same edge set and this test cannot
    detect a flipped edge direction.
    """
    x = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
    ], dtype, device)

    target_to_source = {(0, 1), (0, 3), (1, 0), (1, 2),
                        (2, 1), (2, 3), (3, 0), (3, 2)}
    source_to_target = {(j, i) for i, j in target_to_source}

    edge_index = knn_graph(x, k=2, flow='target_to_source')
    assert to_set(edge_index) == target_to_source

    edge_index = knn_graph(x, k=2, flow='source_to_target')
    assert to_set(edge_index) == source_to_target

    # The TorchScript-compiled function must agree with eager mode.
    jit = torch.jit.script(knn_graph)
    edge_index = jit(x, k=2, flow='source_to_target')
    assert to_set(edge_index) == source_to_target

yangzhong's avatar
yangzhong committed
77
78
79
80
81
82
83
84
85
86
87
88

@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_knn_graph_large(dtype, device):
    """Compare ``knn_graph`` with SciPy's ``cKDTree`` on 1000 random 3-D points.

    ``loop=True`` is passed because querying a tree with its own points
    returns each point as its own nearest neighbour (distance 0).
    """
    k = 5
    # Use a private, seeded CPU generator so the point cloud (and therefore
    # any failure) is reproducible without mutating the global RNG state.
    generator = torch.Generator().manual_seed(12345)
    x = torch.randn(1000, 3, dtype=dtype, generator=generator).to(device)

    edge_index = knn_graph(x, k=k, flow='target_to_source', loop=True)

    # Build the reference from one numpy copy of the points.
    points = x.cpu().numpy()
    tree = scipy.spatial.cKDTree(points)
    _, col = tree.query(points, k=k)
    truth = {(i, int(j))
             for i, neighbours in enumerate(col)
             for j in neighbours}

    assert to_set(edge_index.cpu()) == truth