# Tests for torch_cluster's knn and knn_graph.
from itertools import product

import pytest
import torch
from torch_cluster import knn, knn_graph
from .utils import grad_dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn(dtype, device):
    """Check batched ``knn`` between query points ``y`` and candidates ``x``.

    ``x`` holds the four corners of the square [-1, 1]^2 twice: indices 0-3
    are tagged batch 0 and indices 4-7 batch 1 via ``batch_x``. Each query in
    ``y`` is tagged with its own batch via ``batch_y``, and the expected
    results only contain candidates from the query's batch.
    """
    x = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
    ], dtype, device)
    y = tensor([
        [1, 0],
        [-1, 0],
    ], dtype, device)

    batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
    batch_y = tensor([0, 1], torch.long, device)

    row, col = knn(x, y, 2, batch_x, batch_y)
    # The test does not rely on the order of the k=2 neighbours returned for
    # a query, so sort each query's group before comparing.
    col = col.view(-1, 2).sort(dim=-1)[0].view(-1)

    assert row.tolist() == [0, 0, 1, 1]
    assert col.tolist() == [2, 3, 4, 5]

    # The cosine variant is only checked when the inputs live on CUDA.
    if x.is_cuda:
        row, col = knn(x, y, 2, batch_x, batch_y, cosine=True)
        # Same order-independence as the Euclidean check above.
        col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
        assert row.tolist() == [0, 0, 1, 1]
        # NOTE(review): if ``cosine=True`` ranks candidates by cosine
        # similarity, y[0] = [1, 0] has cosine +1/sqrt(2) to x[2] and x[3] but
        # -1/sqrt(2) to x[0] and x[1], so the expected batch-0 neighbours
        # would be 2 and 3, not 0 and 1. This expectation may be pinning
        # kernel behaviour rather than the metric; confirm against the CUDA
        # implementation.
        assert col.tolist() == [0, 1, 4, 5]


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn_graph(dtype, device):
    """Check ``knn_graph`` with k=2 on the corners of a square, for both flows.

    Each corner's two nearest points are its adjacent corners (distance 2);
    the diagonal corner (distance 2*sqrt(2)) must not be selected.
    """
    x = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
    ], dtype, device)

    # flow='target_to_source': the asserted layout has the query node in
    # ``row`` and its neighbours in ``col``. Sort each node's k=2 neighbours
    # because their order is not part of what is being checked.
    row, col = knn_graph(x, k=2, flow='target_to_source')
    col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
    assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
    assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]

    # flow='source_to_target': the same edges with ``row`` and ``col`` swapped.
    row, col = knn_graph(x, k=2, flow='source_to_target')
    row = row.view(-1, 2).sort(dim=-1)[0].view(-1)
    assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
    assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn_graph_large(dtype, device):
    """Check ``knn_graph`` (k=3, no self-loops) on a fixed 10-point 3-D cloud.

    The same expected edge set is asserted for several ``n_threads`` values,
    so the result must not depend on how the search is split across threads.
    """
    # Build through the ``tensor`` helper so the parametrized ``dtype`` is
    # actually applied (``torch.tensor(...).to(device)`` always gave float32).
    x = tensor([
        [-1.0320, 0.2380, 0.2380],
        [-1.3050, -0.0930, 0.6420],
        [-0.3190, -0.0410, 1.2150],
        [1.1400, -0.5390, -0.3140],
        [0.8410, 0.8290, 0.6090],
        [-1.4380, -0.2420, -0.3260],
        [-2.2980, 0.7160, 0.9320],
        [-1.3680, -0.4390, 0.1380],
        [-0.6710, 0.6060, 1.1800],
        [0.3950, -0.0790, 1.4920],
    ], dtype, device)
    k = 3
    # Expected (row, col) pairs for flow='target_to_source', loop=False:
    # 10 nodes x k=3 = 30 edges.
    truth = {(4, 8), (2, 8), (9, 8), (8, 0), (0, 7), (2, 1), (9, 4),
             (5, 1), (4, 9), (2, 9), (8, 1), (1, 5), (5, 0), (3, 2),
             (8, 2), (7, 1), (6, 0), (3, 9), (0, 5), (7, 5), (4, 2),
             (1, 0), (0, 1), (7, 0), (6, 8), (9, 2), (6, 1), (5, 7),
             (1, 7), (3, 4)}

    for n_threads in (24, 12, 1):
        row, col = knn_graph(x, k=k, flow='target_to_source',
                             batch=None, n_threads=n_threads, loop=False)
        # Compare as an unordered edge set; neighbour order is not checked.
        edges = set(zip(row.tolist(), col.tolist()))
        assert edges == truth, 'n_threads={}'.format(n_threads)