spline_conv.py 1.51 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch.autograd import Variable as Var
rusty1s's avatar
rusty1s committed
3
4

from .degree import node_degree
5
from .spline_weighting import spline_weighting
rusty1s's avatar
rusty1s committed
6
7
8
9
10
11
12
13
14


def spline_conv(x,
                edge_index,
                pseudo,
                weight,
                kernel_size,
                is_open_spline,
                degree=1,
                root_weight=None,
                bias=None):
    """Apply a spline-based graph convolution and return per-node features.

    For every edge, the features of node ``edge_index[1]`` are weighted by
    ``spline_weighting`` using that edge's pseudo-coordinates. The weighted
    edge messages are summed into node ``edge_index[0]``, divided by the
    node degree from ``node_degree`` (clamped to at least 1), and then
    optionally combined with a root-node term and a bias.

    Accepts plain tensors as well as legacy autograd ``Variable``s: when
    ``x`` is not a plain tensor, the helper tensors built here are wrapped
    in ``Variable`` so they can be combined with ``x``.

    Args:
        x: Node features; ``x.size(0)`` is the number of nodes ``n``. A 1-D
            input is unsqueezed to a single feature column.
        edge_index: Index tensor whose ``size(1)`` is the number of edges
            ``e``. Row 0 selects the node that receives each edge message;
            row 1 selects the node whose features are gathered.
        pseudo: Per-edge pseudo-coordinates. A 1-D input is unsqueezed to a
            single column.
        weight: Spline weights; ``weight.size(2)`` is the number of output
            channels ``m_out``.
        kernel_size: Forwarded to ``spline_weighting``.
        is_open_spline: Forwarded to ``spline_weighting``.
        degree: Spline degree, forwarded to ``spline_weighting``. This is
            unrelated to the node degree used for normalization.
        root_weight: Optional matrix. If given, ``torch.mm(x, root_weight)``
            is added to the output.
        bias: Optional term added to the output.

    Returns:
        An ``n x m_out`` tensor (or ``Variable``) of output node features.
    """
    n, e, m_out = x.size(0), edge_index.size(1), weight.size(2)
    # Decide once whether helper tensors need wrapping in (legacy) Variables.
    is_plain_tensor = torch.is_tensor(x)

    x = x.unsqueeze(-1) if x.dim() == 1 else x
    pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo

    # Weight the gathered source-node features of every edge with the spline
    # basis evaluated at that edge's pseudo-coordinates.
    output = spline_weighting(x[edge_index[1]], pseudo, weight, kernel_size,
                              is_open_spline, degree)

    # Perform the real convolution => sum the e x m_out edge messages into
    # their receiving nodes, giving n x m_out features.
    row = edge_index[0].unsqueeze(-1).expand(e, m_out)
    row = row if is_plain_tensor else Var(row)
    zero = x.new(n, m_out) if is_plain_tensor else Var(x.data.new(n, m_out))
    output = zero.fill_(0).scatter_add_(0, row, output)

    # Compute node degree. It is named ``deg`` so it does not shadow the
    # spline ``degree`` argument.
    deg = x.new() if is_plain_tensor else x.data.new()
    deg = node_degree(edge_index, n, out=deg)

    # Normalize output by node degree. Clamping to at least 1 avoids
    # dividing by zero for nodes whose degree is 0.
    deg = deg.unsqueeze(-1).clamp_(min=1)
    output /= deg if is_plain_tensor else Var(deg)

    # Weight root node separately (if wished).
    if root_weight is not None:
        output += torch.mm(x, root_weight)

    # Add bias (if wished).
    if bias is not None:
        output += bias

    return output