from itertools import product

import pytest
import torch
from torch.autograd import gradcheck
from torch_spline_conv import spline_conv

from .utils import dtypes, devices, tensor

# Spline degrees exercised by test_spline_basis_backward.
degrees = [1, 2, 3]

# Hand-computed forward fixtures consumed by test_spline_conv_forward.
tests = [{
    # 5 nodes with 2 input features each.
    'x': [[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]],
    # All four edges share row index 0, so only node 0's expected value
    # contains edge terms; columns 1-4 are the nodes whose features are read.
    'edge_index': [[0, 0, 0, 0], [1, 2, 3, 4]],
    # One 2-d pseudo-coordinate per edge.
    'pseudo': [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]],
    # 12 kernel weights (= 3 * 4, the product of kernel_size), each 2 x 1.
    'weight': [
        [[0.5], [1]],
        [[1.5], [2]],
        [[2.5], [3]],
        [[3.5], [4]],
        [[4.5], [5]],
        [[5.5], [6]],
        [[6.5], [7]],
        [[7.5], [8]],
        [[8.5], [9]],
        [[9.5], [10]],
        [[10.5], [11]],
        [[11.5], [12]],
    ],
    'kernel_size': [3, 4],
    # First pseudo dimension uses an open spline, the second a closed one.
    'is_open_spline': [1, 0],
    'root_weight': [[12.5], [13]],
    'bias': [1],
    # Every row is bias + x[i] @ root_weight (12.5 * x[i][0] + 13 * x[i][1]).
    # Node 0 additionally receives the four hand-computed per-edge messages
    # 8.5, 40.5, 107.5 and 101.5, divided by 4 (presumably the degree
    # normalization enabled by the `True` flag the forward test passes --
    # TODO confirm against spline_conv's signature).
    'expected': [
        [1 + 12.5 * 9 + 13 * 10 + (8.5 + 40.5 + 107.5 + 101.5) / 4],
        [1 + 12.5 * 1 + 13 * 2],
        [1 + 12.5 * 3 + 13 * 4],
        [1 + 12.5 * 5 + 13 * 6],
        [1 + 12.5 * 7 + 13 * 8],
    ]
}]


@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_spline_conv_forward(test, dtype, device):
    """Check spline_conv's forward pass against a hand-computed fixture.

    Every field of ``test`` is materialized on ``device``. Feature-like
    inputs use ``dtype``, index and size tensors use ``torch.long``, and the
    open-spline flags use ``torch.uint8``.
    """
    x = tensor(test['x'], dtype, device)
    edge_index = tensor(test['edge_index'], torch.long, device)
    pseudo = tensor(test['pseudo'], dtype, device)
    weight = tensor(test['weight'], dtype, device)
    kernel_size = tensor(test['kernel_size'], torch.long, device)
    is_open_spline = tensor(test['is_open_spline'], torch.uint8, device)
    root_weight = tensor(test['root_weight'], dtype, device)
    bias = tensor(test['bias'], dtype, device)
    expected = tensor(test['expected'], dtype, device)

    # The positional args after is_open_spline are degree=1 and `True`
    # (presumably the degree-normalization flag -- TODO confirm).
    out = spline_conv(x, edge_index, pseudo, weight, kernel_size,
                      is_open_spline, 1, True, root_weight, bias)

    # Compare with a tolerance instead of exact list equality, so that
    # float32 rounding or a different accumulation order on another device
    # cannot cause spurious failures. allclose broadcasts, so check the
    # shape explicitly.
    assert out.size() == expected.size()
    assert torch.allclose(out, expected)


@pytest.mark.parametrize('degree,device', product(degrees, devices))
def test_spline_basis_backward(degree, device):
    """Gradient-check ``spline_conv`` for each spline degree and device.

    Despite its name, this test differentiates the whole ``spline_conv``
    call. ``gradcheck`` compares analytical gradients against finite
    differences for every input created with ``requires_grad=True``:
    ``x``, ``pseudo``, ``weight``, ``root_weight`` and ``bias``. Those
    inputs are float64 because ``gradcheck``'s default tolerances assume
    double precision.
    """
    # 3 nodes with 2 features each, connected by 4 edges.
    x = torch.rand((3, 2), dtype=torch.double, device=device,
                   requires_grad=True)
    edge_index = tensor([[0, 1, 1, 2], [1, 0, 2, 1]], torch.long, device)
    # One 3-d pseudo-coordinate per edge, drawn uniformly from [0, 1).
    pseudo = torch.rand((4, 3), dtype=torch.double, device=device,
                        requires_grad=True)
    # 5 * 5 * 5 = 125 kernel weights (the product of kernel_size), each
    # 2 x 4, matching x's 2 columns and the 4 entries of bias.
    weight = torch.rand((125, 2, 4), dtype=torch.double, device=device,
                        requires_grad=True)
    kernel_size = tensor([5, 5, 5], torch.long, device)
    # Mix of open (1) and closed (0) splines across the three dimensions.
    is_open_spline = tensor([1, 0, 1], torch.uint8, device)
    root_weight = torch.rand((2, 4), dtype=torch.double, device=device,
                             requires_grad=True)
    bias = torch.rand((4,), dtype=torch.double, device=device,
                      requires_grad=True)

    data = (x, edge_index, pseudo, weight, kernel_size, is_open_spline, degree,
            True, root_weight, bias)
    assert gradcheck(spline_conv, data, eps=1e-6, atol=1e-4) is True