spline_conv.py 1.51 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
import torch
# from torch.autograd import Variable as Var

from .degree import node_degree
rusty1s's avatar
rusty1s committed
5
from .utils import spline_basis, spline_weighting
rusty1s's avatar
rusty1s committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23


def spline_conv(x,
                index,
                pseudo,
                weight,
                kernel_size,
                is_open_spline,
                root_weight=None,
                degree=1,
                bias=None):
    """Apply a B-spline based graph convolution.

    For every edge ``e``:
    - The features of node ``index[1][e]`` are weighted using the B-spline
      basis products of the edge's pseudo-coordinates and the trainable
      ``weight``.
    - The result is summed into node ``index[0][e]``.

    The per-node sums are then divided by the value returned by
    ``node_degree`` (clamped to at least 1). Finally, an optional root term
    and an optional bias are added.

    Args:
        x: Node features, ``N`` or ``N x M_in``. A 1-D input is treated as
            ``N x 1``.
        index: Edge index. ``index[0]`` holds the node each edge is summed
            into and ``index[1]`` the node whose features are gathered.
            Presumably a ``2 x |E|`` LongTensor -- TODO confirm.
        pseudo: Edge pseudo-coordinates, forwarded to ``spline_basis``.
        weight: Trainable spline weights. ``weight.size(0)`` is forwarded to
            ``spline_basis`` (presumably the number of kernel weights --
            confirm).
        kernel_size: Forwarded to ``spline_basis``.
        is_open_spline: Forwarded to ``spline_basis``.
        root_weight: Optional matrix. If given, ``torch.mm(x, root_weight)``
            is added to the output, so it must be ``M_in x M_out``.
        degree: B-spline degree, forwarded to ``spline_basis``.
        bias: Optional tensor added (with broadcasting) to the output.

    Returns:
        Tensor of shape ``N x M_out``, where ``N = x.size(0)``.
    """
    x = x.unsqueeze(-1) if x.dim() == 1 else x
    num_nodes = x.size(0)

    # Gather the features of the index[1] node of every edge => |E| x M_in.
    output = x[index[1]]

    # Get B-spline basis products and weight indices for each edge.
    basis, weight_index = spline_basis(degree, pseudo, kernel_size,
                                       is_open_spline, weight.size(0))

    # Weight gathered features based on B-spline basis and trainable weights
    # => |E| x M_out.
    output = spline_weighting(output, weight, basis, weight_index)

    # Sum every edge result into its index[0] node: |E| x M_out -> N x M_out.
    # scatter_add_ keeps the shape of the tensor it is called on, so the
    # accumulator needs one row per NODE (N), not one row per edge (|E|).
    row = index[0].unsqueeze(-1).expand(-1, output.size(1))
    zero = x.new(num_nodes, output.size(1)).fill_(0)
    output = zero.scatter_add_(0, row, output)

    # Normalize output by node degree. The clamp keeps nodes with degree 0
    # from dividing by zero.
    # NOTE(review): assumes node_degree returns one entry per node (length N)
    # so that it broadcasts against the N x M_out output -- confirm.
    output /= node_degree(index, out=x.new()).unsqueeze(-1).clamp_(min=1)

    # Weight root node separately (if wished).
    if root_weight is not None:
        output += torch.mm(x, root_weight)

    # Add bias (if wished).
    if bias is not None:
        output += bias

    return output