# Tests for torch_cluster.knn.
from itertools import product

import pytest
import torch
from torch_cluster import knn

from .utils import grad_dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn(dtype, device):
    """Check that batched ``knn`` returns the 2 nearest ``x`` points per query.

    ``x`` holds the four corners of the square [-1, 1]^2 twice: once for
    batch 0 (indices 0-3) and once for batch 1 (indices 4-7). Each query in
    ``y`` sits on the midpoint of one side of the square. It therefore has
    exactly two nearest corners, both at distance 1, and every other corner
    is farther away (sqrt(5)).

    Without the batch restriction, query 1 at (-1, 0) would tie between
    indices 0/1 and 4/5. The expected ``[4, 5]`` therefore also checks that
    neighbours are only drawn from the query's own batch.
    """
    x = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
    ], dtype, device)
    y = tensor([
        [1, 0],
        [-1, 0],
    ], dtype, device)

    batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
    batch_y = tensor([0, 1], torch.long, device)

    out = knn(x, y, 2, batch_x, batch_y)
    # Per the expected values: row 0 holds indices into ``y``, two per query
    # (k=2), and row 1 holds the matching indices into ``x``.
    assert out[0].tolist() == [0, 0, 1, 1]
    # A query's two neighbours are equidistant, so their order is not fixed;
    # compare them as sorted lists.
    assert sorted(out[1][:2].tolist()) == [2, 3]
    assert sorted(out[1][2:].tolist()) == [4, 5]