"vscode:/vscode.git/clone" did not exist on "beb49eef65acefc64a6ae0562ce58467e6974fde"
test_conv.py 2.75 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
from itertools import product

import pytest
import torch
rusty1s's avatar
rusty1s committed
5
from torch.autograd import gradcheck
rusty1s's avatar
rusty1s committed
6
from torch_spline_conv import SplineConv
rusty1s's avatar
rusty1s committed
7
from torch_spline_conv.basis import implemented_degrees as degrees
rusty1s's avatar
rusty1s committed
8

rusty1s's avatar
rusty1s committed
9
from .utils import dtypes, devices, tensor
rusty1s's avatar
rusty1s committed
10
11

# Hand-computed forward fixture: 5 nodes with 2 features, 4 edges all
# pointing into node 0, 2-D pseudo-coordinates, a 3 x 4 kernel grid with
# the first dimension open and the second closed, and 1 output feature.
tests = [
    dict(
        x=[[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]],
        edge_index=[[0, 0, 0, 0], [1, 2, 3, 4]],
        pseudo=[[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]],
        # 3 * 4 = 12 kernel weights of shape (2, 1); entry k is
        # [[k + 0.5], [k + 1]], i.e. [[0.5], [1]], [[1.5], [2]], ...
        weight=[[[k + 0.5], [k + 1]] for k in range(12)],
        kernel_size=[3, 4],
        is_open_spline=[1, 0],
        root_weight=[[12.5], [13]],
        bias=[1],
        expected=[
            # bias + root term + mean of the four incoming edge messages.
            # Each message is x_j . W(pseudo_ij) for j = 1..4; e.g. for the
            # edge from node 1 the four active weights each get basis 0.25,
            # giving (2.5 + 5.5 + 11.5 + 14.5) * 0.25 = 8.5.
            [1 + 12.5 * 9 + 13 * 10 + (8.5 + 40.5 + 107.5 + 101.5) / 4],
            # Nodes 1..4 have no incoming edges: bias + root term only.
            [1 + 12.5 * 1 + 13 * 2],
            [1 + 12.5 * 3 + 13 * 4],
            [1 + 12.5 * 5 + 13 * 6],
            [1 + 12.5 * 7 + 13 * 8],
        ],
    ),
]


rusty1s's avatar
rusty1s committed
43
44
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_spline_conv_forward(test, dtype, device):
    """Compare a degree-1 ``SplineConv`` forward pass with a fixture.

    Floating-point inputs are materialised in ``dtype`` and index inputs as
    ``torch.long`` on ``device``; the output must match ``test['expected']``
    exactly.
    """
    def as_values(key):
        return tensor(test[key], dtype, device)

    def as_indices(key):
        return tensor(test[key], torch.long, device)

    # Arguments are built left to right, in the operator's positional order.
    out = SplineConv.apply(
        as_values('x'),
        as_indices('edge_index'),
        as_values('pseudo'),
        as_values('weight'),
        as_indices('kernel_size'),
        tensor(test['is_open_spline'], torch.uint8, device),
        1,  # spline degree
        # Presumably the normalisation flag: the fixture expects the sum of
        # incoming messages divided by the in-degree.
        True,
        as_values('root_weight'),
        as_values('bias'),
    )
    assert out.tolist() == test['expected']
rusty1s's avatar
rusty1s committed
57
58
59
60


@pytest.mark.parametrize('degree,device', product(degrees.keys(), devices))
def test_spline_basis_backward(degree, device):
    """Gradcheck ``SplineConv`` for every implemented spline degree.

    Despite its name, this test checks the backward pass of the whole
    ``SplineConv`` operator. It compares the gradients w.r.t. ``x``,
    ``pseudo``, ``weight``, ``root_weight`` and ``bias`` against finite
    differences via ``torch.autograd.gradcheck``.

    Random inputs come from a fixed-seed, test-local generator. A failure
    (e.g. a finite-difference step that straddles a spline knot) is
    therefore reproducible, and the global RNG state is left untouched.
    """
    gen = torch.Generator().manual_seed(12345)

    def rand_leaf(*size):
        # Sample with the CPU generator and move afterwards: a CPU generator
        # cannot drive sampling directly on a CUDA device. The moved tensor
        # has no autograd history yet, so it is still a leaf.
        sample = torch.rand(size, generator=gen, dtype=torch.double)
        return sample.to(device).requires_grad_()

    # 3 nodes with 2 input features, connected by 4 directed edges.
    x = rand_leaf(3, 2)
    edge_index = tensor([[0, 1, 1, 2], [1, 0, 2, 1]], torch.long, device)
    # One 3-D pseudo-coordinate per edge.
    pseudo = rand_leaf(4, 3)
    # 5 * 5 * 5 = 125 kernel weights, each mapping 2 -> 4 features.
    weight = rand_leaf(125, 2, 4)
    kernel_size = tensor([5, 5, 5], torch.long, device)
    is_open_spline = tensor([1, 0, 1], torch.uint8, device)
    root_weight = rand_leaf(2, 4)
    bias = rand_leaf(4)

    data = (x, edge_index, pseudo, weight, kernel_size, is_open_spline, degree,
            True, root_weight, bias)
    # gradcheck needs double precision; it returns True or raises.
    assert gradcheck(SplineConv.apply, data, eps=1e-6, atol=1e-4) is True