# test_permute.py
# Tests for `sort` and `permute` from torch_cluster.functions.utils.permute,
# on CPU tensors and (when CUDA is available) on GPU tensors.
import pytest
import torch
from torch_cluster.functions.utils.permute import sort, permute


def test_sort_cpu():
    """`sort` orders an edge list by its ``row`` entries (CPU tensors).

    The input is a symmetric 4-node edge list (every ``(i, j)`` also appears
    as ``(j, i)``) given in arbitrary order. After sorting, ``row`` must be
    non-decreasing and every ``col`` entry must still be paired with its
    original ``row`` entry, per the expected values below.
    """
    row = torch.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
    col = torch.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
    row, col = sort(row, col)
    expected_row = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3]
    expected_col = [1, 2, 0, 2, 3, 0, 1, 3, 1, 2]
    assert row.tolist() == expected_row
    assert col.tolist() == expected_col


def test_permute_cpu():
    """`permute` reorders edges according to a node permutation (CPU tensors).

    With ``node_rid = [2, 1, 3, 0]``, the expected output groups edges by
    source node in the order 3, 1, 0, 2 (i.e. ``argsort(node_rid)``). The
    node ids themselves are not relabeled. ``edge_rid`` is the identity here,
    and within each group the edges keep their original relative order.
    """
    row = torch.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
    col = torch.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
    node_rid = torch.LongTensor([2, 1, 3, 0])
    edge_rid = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    # NOTE(review): `4` is presumably the number of nodes -- confirm against
    # permute's signature.
    row, col = permute(row, col, 4, node_rid, edge_rid)
    expected_row = [3, 3, 1, 1, 1, 0, 0, 2, 2, 2]
    expected_col = [1, 2, 0, 2, 3, 1, 2, 0, 1, 3]
    assert row.tolist() == expected_row
    assert col.tolist() == expected_col


@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_sort_gpu():  # pragma: no cover
    """`sort` orders an edge list by its ``row`` entries (CUDA tensors).

    This uses the same symmetric 4-node edge list and expected values as the
    CPU variant. Results are copied back with ``.cpu()`` before comparison.
    The test is skipped when CUDA is unavailable.
    """
    row = torch.cuda.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
    col = torch.cuda.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
    row, col = sort(row, col)
    expected_row = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3]
    expected_col = [1, 2, 0, 2, 3, 0, 1, 3, 1, 2]
    assert row.cpu().tolist() == expected_row
    assert col.cpu().tolist() == expected_col


@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_permute_gpu():  # pragma: no cover
    """`permute` reorders edges according to a node permutation (CUDA tensors).

    This uses the same inputs and expected values as the CPU variant: edges
    are grouped by source node in the order ``argsort(node_rid)`` = 3, 1, 0, 2,
    and node ids are not relabeled. Results are copied back with ``.cpu()``
    before comparison. The test is skipped when CUDA is unavailable.
    """
    row = torch.cuda.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
    col = torch.cuda.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
    node_rid = torch.cuda.LongTensor([2, 1, 3, 0])
    edge_rid = torch.cuda.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    # NOTE(review): `4` is presumably the number of nodes -- confirm against
    # permute's signature.
    row, col = permute(row, col, 4, node_rid, edge_rid)
    expected_row = [3, 3, 1, 1, 1, 0, 0, 2, 2, 2]
    expected_col = [1, 2, 0, 2, 3, 1, 2, 0, 1, 3]
    assert row.cpu().tolist() == expected_row
    assert col.cpu().tolist() == expected_col