import torch
from torch.autograd import Variable as Var

from .degree import node_degree
from .utils import spline_basis, spline_weighting


def _spline_conv(x,
                 edge_index,
                 pseudo,
                 weight,
                 kernel_size,
                 is_open_spline,
                 degree=1):
    """Sum spline-weighted edge messages into their destination nodes.

    For every edge, the features of node ``edge_index[1]`` are gathered and
    weighted by ``spline_weighting``, using the basis values and weight
    indices returned by ``spline_basis``. The resulting per-edge messages
    are summed into node ``edge_index[0]`` with ``scatter_add_``. No
    normalization is applied here.

    Args:
        x: Node features. A 1-D input is treated as a single feature channel.
        edge_index: Edge indices. Row 0 selects the node each message is
            added to; row 1 selects the node whose features are gathered.
            The number of edges is ``edge_index.size(1)``.
        pseudo: Per-edge pseudo-coordinates, forwarded to ``spline_basis``
            (expected layout not checked here; TODO confirm there).
        weight: Trainable weights unpacked as ``K x M_in x M_out``.
        kernel_size: Forwarded to ``spline_basis``.
        is_open_spline: Forwarded to ``spline_basis``.
        degree: Spline degree forwarded to ``spline_basis``.

    Returns:
        An ``N x M_out`` tensor (or ``Variable``, matching ``x``), where
        ``N = x.size(0)``. Nodes that never appear in ``edge_index[0]``
        stay zero.
    """
    num_nodes, num_edges = x.size(0), edge_index.size(1)
    # Unpacking also enforces a rank-3 weight; M_in itself is not needed.
    K, _, m_out = weight.size()

    x = x.unsqueeze(-1) if x.dim() == 1 else x
    is_tensor = torch.is_tensor(x)

    # Get features for every target node => |E| x M_in
    output = x[edge_index[1]]

    # Get B-spline basis products and weight indices for each edge.
    basis, weight_index = spline_basis(degree, pseudo, kernel_size,
                                       is_open_spline, K)

    # Weight gathered features based on B-spline basis and trainable weights.
    output = spline_weighting(output, weight, basis, weight_index)

    # Perform the real convolution => Convert |E| x M_out to N x M_out output.
    # The destination row index is expanded across the output channels so
    # every element of `output` has a target row for `scatter_add_`.
    row = edge_index[0].unsqueeze(-1).expand(num_edges, m_out)
    if is_tensor:
        accumulator = x.new(num_nodes, m_out).fill_(0)
    else:
        # Variable input (pre-0.4 autograd API): wrap the index and the
        # accumulator as Variables to match `output`.
        row = Var(row)
        accumulator = Var(x.data.new(num_nodes, m_out)).fill_(0)
    return accumulator.scatter_add_(0, row, output)


def spline_conv(x,
                edge_index,
                pseudo,
                weight,
                kernel_size,
                is_open_spline,
                root_weight=None,
                degree=1,
                bias=None):
    """Apply a spline-based graph convolution to node features ``x``.

    The steps are:
      1. Aggregate per-node messages with ``_spline_conv``.
      2. Divide by the node degree reported by ``node_degree``, clamped to
         at least 1.
      3. Optionally add the root-node term ``x @ root_weight``.
      4. Optionally add ``bias``.

    Args:
        x: Node features. A 1-D input is treated as a single feature channel,
            both for message passing and for the root-weight product.
        edge_index: Edge indices, forwarded to ``_spline_conv`` and
            ``node_degree``.
        pseudo: Forwarded to ``_spline_conv``.
        weight: Trainable spline weights (``K x M_in x M_out``).
        kernel_size: Forwarded to ``_spline_conv``.
        is_open_spline: Forwarded to ``_spline_conv``.
        root_weight: Optional ``M_in x M_out`` matrix applied to each node's
            own features.
        degree: Spline degree forwarded to ``_spline_conv``. This is not the
            node degree used for normalization.
        bias: Optional term added in place to the output. It must broadcast
            to the output's shape.

    Returns:
        The ``N x M_out`` output tensor (or ``Variable``, matching ``x``).
    """
    # Promote 1-D features to N x 1 up front, exactly as `_spline_conv` does
    # internally. Otherwise `torch.mm(x, root_weight)` below would fail,
    # because `mm` only accepts 2-D matrices.
    x = x.unsqueeze(-1) if x.dim() == 1 else x
    n = x.size(0)
    is_tensor = torch.is_tensor(x)

    output = _spline_conv(x, edge_index, pseudo, weight, kernel_size,
                          is_open_spline, degree)

    # Normalize output by node degree. Clamping to 1 prevents a division by
    # zero for nodes whose reported degree is 0. The local is named
    # `node_deg` so it does not shadow the spline `degree` argument.
    node_deg = x.new() if is_tensor else x.data.new()
    node_deg = node_degree(edge_index, n, out=node_deg)
    node_deg = node_deg.unsqueeze(-1).clamp_(min=1)
    output /= node_deg if is_tensor else Var(node_deg)

    # Weight root node separately (if wished).
    if root_weight is not None:
        output += torch.mm(x, root_weight)

    # Add bias (if wished).
    if bias is not None:
        output += bias

    return output