import pytest
import torch
from torch_cluster.functions.utils.permute import sort, permute


def test_sort_cpu():
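    # `sort` orders the edge list by row index while keeping each column entry
    # aligned with its row. On the CPU the sort is stable, so entries with
    # equal rows keep their original relative order.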
    row = torch.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
    col = torch.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
    row, col = sort(row, col)
    expected_row = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3]
    expected_col = [1, 2, 0, 2, 3, 0, 1, 3, 1, 2]
    assert row.tolist() == expected_row
    assert col.tolist() == expected_col


def test_permute_cpu():
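    # The expected output corresponds to `permute` regrouping the edges by the
    # node permutation: rows appear in the order 3, 1, 0, 2 (nodes sorted by
    # their new position `node_rid[i]`), while the node ids themselves are
    # left unchanged.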
    row = torch.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
    col = torch.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
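    # `node_rid[i]` gives the new position of node `i`; `edge_rid` (presumably
    # a permutation over the ten edges) is the identity here, so only the node
    # reordering is exercised.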
    node_rid = torch.LongTensor([2, 1, 3, 0])
    edge_rid = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    row, col = permute(row, col, 4, node_rid, edge_rid)
    expected_row = [3, 3, 1, 1, 1, 0, 0, 2, 2, 2]
    expected_col = [1, 2, 0, 2, 3, 1, 2, 0, 1, 3]
    assert row.tolist() == expected_row
    assert col.tolist() == expected_col


@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_sort_gpu():  # pragma: no cover
    # Note that `sort` is not guaranteed to be stable on the GPU, so the
    # relative ordering of column entries within equal rows may differ from
    # the (stable) CPU version.
    row = torch.cuda.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
    col = torch.cuda.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
    row, col = sort(row, col)
    expected_row = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3]
    expected_col = [1, 2, 0, 2, 3, 0, 1, 3, 1, 2]
    assert row.cpu().tolist() == expected_row
    assert col.cpu().tolist() == expected_col


@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_permute_gpu():  # pragma: no cover
    # As with `sort`, `permute` is not guaranteed to be stable on the GPU
    # (see above).
    row = torch.cuda.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
    col = torch.cuda.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
    node_rid = torch.cuda.LongTensor([2, 1, 3, 0])
    edge_rid = torch.cuda.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    row, col = permute(row, col, 4, node_rid, edge_rid)
    expected_row = [3, 3, 1, 1, 1, 0, 0, 2, 2, 2]
    expected_col = [1, 2, 0, 2, 3, 1, 2, 0, 1, 3]
    assert row.cpu().tolist() == expected_row
    assert col.cpu().tolist() == expected_col