test_basis.py 1.95 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
from itertools import product

import pytest
import torch
rusty1s's avatar
rusty1s committed
5
from torch_spline_conv.basis import basis_forward
rusty1s's avatar
rusty1s committed
6

rusty1s's avatar
rusty1s committed
7
from .tensor import tensors
rusty1s's avatar
rusty1s committed
8

rusty1s's avatar
rusty1s committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# Expected basis_forward results. Each case gives the input ``pseudo``
# coordinates, the per-dimension ``kernel_size`` and ``is_open_spline``
# flags, and the ``basis`` / ``weight_index`` rows the tests compare the
# output against (one row per ``pseudo`` entry).
tests = [
    # Case 0: one dimension, kernel_size 5, is_open_spline 1.
    {
        'pseudo': [0, 0.0625, 0.25, 0.75, 0.9375, 1],
        'kernel_size': [5],
        'is_open_spline': [1],
        'basis': [
            [1, 0],
            [0.75, 0.25],
            [1, 0],
            [1, 0],
            [0.25, 0.75],
            [1, 0],
        ],
        'weight_index': [
            [0, 1],
            [0, 1],
            [1, 2],
            [3, 4],
            [3, 4],
            [4, 0],
        ],
    },
    # Case 1: same pseudo values, kernel_size 4, is_open_spline 0.
    {
        'pseudo': [0, 0.0625, 0.25, 0.75, 0.9375, 1],
        'kernel_size': [4],
        'is_open_spline': [0],
        'basis': [
            [1, 0],
            [0.75, 0.25],
            [1, 0],
            [1, 0],
            [0.25, 0.75],
            [1, 0],
        ],
        'weight_index': [
            [0, 1],
            [0, 1],
            [1, 2],
            [3, 0],
            [3, 0],
            [0, 1],
        ],
    },
    # Case 2: two dimensions, kernel_size 5 x 5, both is_open_spline 1.
    {
        'pseudo': [[0.125, 0.5], [0.5, 0.5], [0.75, 0.125]],
        'kernel_size': [5, 5],
        'is_open_spline': [1, 1],
        'basis': [
            [0.5, 0.5, 0, 0],
            [1, 0, 0, 0],
            [0.5, 0, 0.5, 0],
        ],
        'weight_index': [
            [10, 11, 15, 16],
            [12, 13, 17, 18],
            [3, 4, 8, 9],
        ],
    },
]
rusty1s's avatar
rusty1s committed
28
29


rusty1s's avatar
rusty1s committed
30
31
32
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_basis_forward_cpu(tensor, i):
    """Check CPU ``basis_forward`` output against expected case ``tests[i]``.

    ``tensor`` names a ``torch`` tensor constructor used to build the
    ``pseudo`` input; ``kernel_size`` and ``is_open_spline`` are always built
    as ``LongTensor`` and ``ByteTensor``. Both outputs are compared exactly,
    via ``tolist()``, with the expected nested lists.
    """
    case = tests[i]
    make_pseudo = getattr(torch, tensor)

    out_basis, out_index = basis_forward(
        1,  # presumably the spline degree -- confirm against basis_forward
        make_pseudo(case['pseudo']),
        torch.LongTensor(case['kernel_size']),
        torch.ByteTensor(case['is_open_spline']),
    )

    assert out_basis.tolist() == case['basis']
    assert out_index.tolist() == case['weight_index']
rusty1s's avatar
rusty1s committed
41
42
43


@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_basis_forward_gpu(tensor, i):  # pragma: no cover
    """Check CUDA ``basis_forward`` output against expected case ``tests[i]``.

    Skipped when CUDA is unavailable. All inputs are built with ``torch.cuda``
    constructors (``tensor`` names the one used for ``pseudo``). The outputs
    are moved back with ``.cpu()`` before the exact ``tolist()`` comparison.
    """
    case = tests[i]
    make_pseudo = getattr(torch.cuda, tensor)

    out_basis, out_index = basis_forward(
        1,  # presumably the spline degree -- confirm against basis_forward
        make_pseudo(case['pseudo']),
        torch.cuda.LongTensor(case['kernel_size']),
        torch.cuda.ByteTensor(case['is_open_spline']),
    )

    assert out_basis.cpu().tolist() == case['basis']
    assert out_index.cpu().tolist() == case['weight_index']