"cacheflow/vscode:/vscode.git/clone" did not exist on "8274ca23ac9f2ea0ccf758d1883794643aecb2e0"
test_consecutive.py 642 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
import pytest
rusty1s's avatar
tests  
rusty1s committed
2
3
4
5
import torch
from torch_cluster.functions.utils.consecutive import consecutive


rusty1s's avatar
rusty1s committed
6
def test_consecutive_cpu():
    """Check ``consecutive`` on CPU ``LongTensor`` inputs.

    The expected outputs relabel the distinct input values as 0, 1, 2, ...
    in ascending order, repeating a label wherever its value repeats.
    """
    cases = [
        ([0, 2, 3], [0, 1, 2]),
        ([0, 3, 2, 2, 3], [0, 2, 1, 1, 2]),
    ]
    for values, expected in cases:
        result = consecutive(torch.LongTensor(values))
        assert result.tolist() == expected

rusty1s's avatar
rusty1s committed
13
14
15
16
17
18
19
20

@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_consecutive_gpu():  # pragma: no cover
    """Check ``consecutive`` on CUDA ``LongTensor`` inputs.

    Uses the same cases as the CPU test. Results are copied back to the
    host with ``.cpu()`` before being compared as Python lists. The test is
    skipped when no CUDA device is available.
    """
    cases = [
        ([0, 2, 3], [0, 1, 2]),
        ([0, 3, 2, 2, 3], [0, 2, 1, 1, 2]),
    ]
    for values, expected in cases:
        result = consecutive(torch.cuda.LongTensor(values))
        assert result.cpu().tolist() == expected