test_nearest.py 1.97 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
from itertools import product

import pytest
import torch
from torch_cluster import nearest
limm's avatar
limm committed
6
from torch_cluster.testing import devices, grad_dtypes, tensor
yangzhong's avatar
yangzhong committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_nearest(dtype, device):
    """Exercise ``nearest`` on a small 2D example and on malformed batches.

    The first part asserts the indices returned for eight query points ``x``
    against four candidate points ``y``, both with explicit batch vectors and
    without any. The second part asserts that ``nearest`` raises
    ``ValueError`` for batch vectors that are inconsistent between ``x`` and
    ``y`` or that are not sorted.

    Args:
        dtype: Floating point dtype used to build ``x`` and ``y``.
        device: Device on which all tensors are created.
    """
    x = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
        [-2, -2],
        [-2, +2],
        [+2, +2],
        [+2, -2],
    ], dtype, device)
    y = tensor([
        [-1, 0],
        [+1, 0],
        [-2, 0],
        [+2, 0],
    ], dtype, device)

    batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
    batch_y = tensor([0, 0, 1, 1], torch.long, device)

    # With batch vectors: each point of x is matched within its own example.
    out = nearest(x, y, batch_x, batch_y)
    assert out.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]

    # Without batch vectors the same assignment is expected for this data.
    out = nearest(x, y)
    assert out.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]

    # Invalid input: instance 1 only in batch_x
    batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
    batch_y = tensor([0, 0, 0, 0], torch.long, device)
    with pytest.raises(ValueError):
        nearest(x, y, batch_x, batch_y)

    # Invalid input: instance 1 only in batch_x (implicitly as batch_y=None)
    with pytest.raises(ValueError):
        nearest(x, y, batch_x, batch_y=None)

    # Invalid input: instance 2 only in batch_x
    # (i.e. instance in the middle missing)
    batch_x = tensor([0, 0, 1, 1, 2, 2, 3, 3], torch.long, device)
    batch_y = tensor([0, 1, 3, 3], torch.long, device)
    with pytest.raises(ValueError):
        nearest(x, y, batch_x, batch_y)

    # Invalid input: batch_x unsorted
    # batch_x carries one entry per row of x (8), so ordering is the only
    # thing wrong with it and a length mismatch cannot mask the check.
    batch_x = tensor([0, 0, 1, 0, 0, 0, 0, 0], torch.long, device)
    batch_y = tensor([0, 0, 1, 1], torch.long, device)
    with pytest.raises(ValueError):
        nearest(x, y, batch_x, batch_y)

    # Invalid input: batch_y unsorted
    batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
    batch_y = tensor([0, 0, 1, 0], torch.long, device)
    with pytest.raises(ValueError):
        nearest(x, y, batch_x, batch_y)