test_basis.py 1.11 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
from os import path as osp
from itertools import product

import pytest
import json
import torch
7
from torch_spline_conv.functions.ffi import spline_basis_forward
rusty1s's avatar
rusty1s committed
8
9
10
11
12
13
14
15
16
17
18
19

from .utils import tensors, Tensor

# Load the spline-basis fixture cases stored next to this test module.
# test_spline_basis_cpu indexes ``data`` by position (presumably a JSON array
# of case objects -- confirm against basis.json).  The ``with`` block closes
# the file even if parsing fails; JSON text is read explicitly as UTF-8 rather
# than in the platform's locale encoding.
with open(osp.join(osp.dirname(__file__), 'basis.json'), 'r',
          encoding='utf-8') as f:
    data = json.load(f)


@pytest.mark.parametrize('tensor,i', product(tensors, range(len(data))))
def test_spline_basis_cpu(tensor, i):
    """Check ``spline_basis_forward`` against the ``basis.json`` fixtures.

    The test runs once per (tensor type, fixture case) pair.

    Args:
        tensor: An entry of ``tensors``, passed to ``Tensor`` to build the
            floating-point inputs (``pseudo``, ``expected_basis``).
        i: Position of the fixture case in ``data``.
    """
    case = data[i]

    # ``.get`` yields None when a case omits 'degree'.
    degree = case.get('degree')
    pseudo = Tensor(tensor, case['pseudo'])
    # A 1-D pseudo tensor is promoted to a single-column 2-D tensor.
    pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
    kernel_size = torch.LongTensor(case['kernel_size'])
    is_open_spline = torch.ByteTensor(case['is_open_spline'])
    K = kernel_size.prod()
    expected_basis = Tensor(tensor, case['expected_basis'])
    # Indices are stored as int64: a uint8 ByteTensor cannot represent values
    # >= 256, which larger kernel_size products can produce (presumably the
    # indices range over [0, K) -- confirm).
    expected_index = torch.LongTensor(case['expected_index'])

    basis, index = spline_basis_forward(degree, pseudo, kernel_size,
                                        is_open_spline, K)

    # Basis values are compared element-wise with a 1% relative tolerance;
    # indices must match exactly.
    assert basis.view(-1).tolist() == pytest.approx(
        expected_basis.view(-1).tolist(), rel=0.01)
    assert index.tolist() == expected_index.tolist()