test_permute.py
import pytest
import torch
from torch_cluster.functions.utils.permute import sort, permute


def test_sort_cpu():
    edge_index = torch.LongTensor([
        [0, 1, 0, 2, 1, 2, 1, 3, 2, 3],
        [1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
    ])
    # `sort` is expected to return the edges ordered lexicographically by
    # (row, col); a pure-PyTorch reference sketch follows this test.
    expected_edge_index = [
        [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
        [1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
    ]
    assert sort(edge_index).tolist() == expected_edge_index

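# --- Hedged reference sketch (assumed, not part of torch_cluster) ----------
# Judging from `test_sort_cpu`, `sort` orders the edges lexicographically by
# (row, col).  The hypothetical helper below reproduces that expected output
# with plain PyTorch ops; it is an illustration, not the library code.
def sort_reference(edge_index):
    row, col = edge_index
    num_nodes = int(edge_index.max()) + 1
    # Encode each edge as a single integer key so that sorting the keys
    # orders edges by row first and by column second.
    perm = (row * num_nodes + col).argsort()
    return edge_index[:, perm]
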

def test_permute_cpu():
    edge_index = torch.LongTensor([
        [0, 1, 0, 2, 1, 2, 1, 3, 2, 3],
        [1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
    ])
    node_rid = torch.LongTensor([2, 1, 3, 0])
    edge_rid = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    # node_rid assigns each node a new rank (as reconstructed from the expected
    # output below); edge_rid is the identity here.  A hedged reference sketch
    # of the expected reordering follows this test.

    edge_index = permute(edge_index, 4, node_rid, edge_rid)
    expected_edge_index = [
        [3, 3, 1, 1, 1, 0, 0, 2, 2, 2],
        [1, 2, 0, 2, 3, 1, 2, 0, 1, 3],
    ]

    assert edge_index.tolist() == expected_edge_index

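# --- Hedged reference sketch (assumed, not part of torch_cluster) ----------
# Reconstructed from the expected output of `test_permute_cpu`: `permute`
# keeps each edge's (row, col) values but regroups the edges so that the
# source-node groups appear in the order given by node_rid, where node_rid[v]
# is the new rank of node v.  The role of edge_rid is not modelled here (it is
# the identity in the test above); hypothetical helper, not the library code.
def permute_reference(edge_index, num_nodes, node_rid):
    row, col = edge_index
    # Rank edges by the new rank of their source node, breaking ties by the
    # column value so columns stay sorted within each group.
    key = node_rid[row] * num_nodes + col
    perm = key.argsort()
    return edge_index[:, perm]
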

@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_sort_gpu():  # pragma: no cover
    edge_index = torch.cuda.LongTensor([
        [0, 1, 0, 2, 1, 2, 1, 3, 2, 3],
        [1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
    ])
    expected_row = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3]
    assert sort(edge_index)[0].cpu().tolist() == expected_row


@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_permute_gpu():  # pragma: no cover
    edge_index = torch.cuda.LongTensor([
        [0, 1, 0, 2, 1, 2, 1, 3, 2, 3],
        [1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
    ])
    node_rid = torch.cuda.LongTensor([2, 1, 3, 0])

    edge_index = permute(edge_index, 4, node_rid)
    expected_row = [3, 3, 1, 1, 1, 0, 0, 2, 2, 2]

    assert edge_index[0].cpu().tolist() == expected_row