# Tests for torch_cluster.functions.utils.consecutive.consecutive
import pytest
import torch
from torch_cluster.functions.utils.consecutive import consecutive


def test_consecutive_cpu():
    """Check ``consecutive`` on a CPU tensor.

    The input uses the values {0, 2, 3}. The expected output shows them
    relabelled to {0, 1, 2} in the same relative order, with the input's
    shape kept.
    """
    # ``torch.tensor`` with an explicit dtype replaces the legacy
    # ``torch.LongTensor`` constructor (same int64 result).
    x = torch.tensor([[0, 3, 2], [2, 3, 0]], dtype=torch.long)
    assert consecutive(x).tolist() == [[0, 2, 1], [1, 2, 0]]


@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_consecutive_gpu():  # pragma: no cover
    """Check ``consecutive`` on a CUDA tensor.

    The inputs and expected values match the CPU test. The test is skipped
    when CUDA is unavailable. It is excluded from coverage, presumably
    because CI has no GPU.
    """
    # Non-legacy equivalent of ``torch.cuda.LongTensor([...])``.
    x = torch.tensor([[0, 3, 2], [2, 3, 0]], dtype=torch.long, device='cuda')
    # Move the result to the host before converting it to a Python list.
    assert consecutive(x).cpu().tolist() == [[0, 2, 1], [1, 2, 0]]