# test_conv.py: tests for torch_spline_conv.spline_conv (forward values and gradients).
from itertools import product

import pytest
import torch
from torch.autograd import gradcheck
from torch_spline_conv import spline_conv
from torch_spline_conv.testing import devices, dtypes, tensor

# Spline degrees exercised by the gradient check.
degrees = list(range(1, 4))

# A single hand-computed forward case. Reading the 'expected' rows below:
# row 0 is bias + root term + the mean of four per-edge spline messages
# (their sum divided by 4); rows 1-4 have no message term at all.
tests = [{
    'x': [[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]],
    'edge_index': [[0, 0, 0, 0], [1, 2, 3, 4]],
    'pseudo': [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]],
    # kernel_size 3 x 4 gives 12 weight matrices of shape (2, 1);
    # entry k is [[k + 0.5], [k + 1]].
    'weight': [[[k + 0.5], [k + 1]] for k in range(12)],
    'kernel_size': [3, 4],
    'is_open_spline': [1, 0],
    'root_weight': [[12.5], [13]],
    'bias': [1],
    'expected': [
        # bias + root_weight . x[0] + mean of the four edge messages
        [1 + 12.5 * 9 + 13 * 10 + (8.5 + 40.5 + 107.5 + 101.5) / 4],
    ] + [
        # bias + root_weight . x[i] only
        [1 + 12.5 * a + 13 * b] for a, b in [(1, 2), (3, 4), (5, 6), (7, 8)]
    ],
}]


@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_spline_conv_forward(test, dtype, device):
    """Check ``spline_conv`` output against a hand-computed fixture.

    Each case in ``tests`` is run for every dtype/device combination, and
    the result is compared with the case's ``'expected'`` values.
    """
    # bfloat16 is not exercised on CUDA. Skip explicitly (rather than a bare
    # ``return``) so the run reports the case as skipped instead of passed.
    # Match any CUDA device, not only ``cuda:0``.
    if dtype == torch.bfloat16 and str(device).startswith('cuda'):
        pytest.skip('bfloat16 is not tested on CUDA')

    x = tensor(test['x'], dtype, device)
    edge_index = tensor(test['edge_index'], torch.long, device)
    pseudo = tensor(test['pseudo'], dtype, device)
    weight = tensor(test['weight'], dtype, device)
    kernel_size = tensor(test['kernel_size'], torch.long, device)
    is_open_spline = tensor(test['is_open_spline'], torch.uint8, device)
    root_weight = tensor(test['root_weight'], dtype, device)
    bias = tensor(test['bias'], dtype, device)
    expected = tensor(test['expected'], dtype, device)

    # Positional args after is_open_spline: spline degree 1, then True
    # (presumably degree normalization; confirm against spline_conv's signature).
    out = spline_conv(x, edge_index, pseudo, weight, kernel_size,
                      is_open_spline, 1, True, root_weight, bias)

    # Low-precision bfloat16 gets a much looser tolerance.
    error = 1e-2 if dtype == torch.bfloat16 else 1e-7
    assert torch.allclose(out, expected, rtol=error, atol=error)


@pytest.mark.parametrize('degree,device', product(degrees, devices))
def test_spline_conv_backward(degree, device):
    """Numerically verify ``spline_conv`` gradients with ``gradcheck``.

    Random double-precision inputs are used on a 3-node graph with 4 edges
    and 3-dimensional pseudo-coordinates. Gradients are checked with respect
    to ``x``, ``pseudo``, ``weight``, ``root_weight`` and ``bias``.
    """
    in_channels, out_channels = 2, 4
    # Differentiable inputs: double precision (as gradcheck expects) and
    # created as leaves that require grad.
    opts = dict(dtype=torch.double, device=device, requires_grad=True)

    edge_index = tensor([[0, 1, 1, 2], [1, 0, 2, 1]], torch.long, device)
    kernel_size = tensor([5, 5, 5], torch.long, device)
    is_open_spline = tensor([1, 0, 1], torch.uint8, device)
    # One weight matrix per kernel position (5 * 5 * 5 = 125); derive the
    # count from kernel_size so the two cannot drift apart.
    num_kernels = int(kernel_size.prod())
    num_edges, num_dims = edge_index.size(1), kernel_size.numel()

    # Random draws keep the original order: x, pseudo, weight, root_weight, bias.
    x = torch.rand((3, in_channels), **opts)
    pseudo = torch.rand((num_edges, num_dims), **opts)
    weight = torch.rand((num_kernels, in_channels, out_channels), **opts)
    root_weight = torch.rand((in_channels, out_channels), **opts)
    # A real 1-tuple shape; the former ``(4)`` was just the int 4.
    bias = torch.rand((out_channels, ), **opts)

    data = (x, edge_index, pseudo, weight, kernel_size, is_open_spline, degree,
            True, root_weight, bias)
    assert gradcheck(spline_conv, data, eps=1e-6, atol=1e-4) is True