conv.py 3.05 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch.autograd import Variable
rusty1s's avatar
linting  
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
5
from .basis import spline_basis
from .weighting import spline_weighting
rusty1s's avatar
linting  
rusty1s committed
6

rusty1s's avatar
rusty1s committed
7
8
9
10
11
from .utils.new import new
from .utils.degree import node_degree


def spline_conv(src,
                edge_index,
                pseudo,
                weight,
                kernel_size,
                is_open_spline,
                degree=1,
                root_weight=None,
                bias=None):
    r"""Applies the spline-based convolution operator :math:`(f \star g)(i) =
    \frac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)}
    f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph.
    Here, :math:`g_l` denotes the kernel function defined over the weighted
    B-spline tensor product basis for a single input feature map :math:`l`.

    Args:
        src (Tensor or Variable): Input node features of shape
            (number_of_nodes x in_channels). A 1-dimensional input is treated
            as a single input channel.
        edge_index (LongTensor): Graph edges, given by source and target
            indices, of shape (2 x number_of_edges). Features are gathered
            from the nodes in ``edge_index[1]`` and aggregated into the nodes
            in ``edge_index[0]``.
        pseudo (Tensor or Variable): Edge attributes, i.e. pseudo coordinates,
            of shape (number_of_edges x number_of_edge_attributes) in the
            fixed interval [0, 1]. A 1-dimensional input is treated as a
            single edge attribute.
        weight (Tensor or Variable): Trainable weight parameters of shape
            (kernel_size x in_channels x out_channels)
        kernel_size (LongTensor): Number of trainable weight parameters in each
            edge dimension
        is_open_spline (ByteTensor): Whether to use open or closed B-spline
            bases for each dimension
        degree (int): B-spline basis degree (default: :obj:`1`)
        root_weight (Tensor or Variable): Additional shared trainable
            parameters for each feature of the root node of shape
            (in_channels x out_channels) (default: :obj:`None`)
        bias (Tensor or Variable): Optional bias of shape (out_channels)
            (default: :obj:`None`)

    Returns:
        Tensor or Variable: Output node features of shape
        (number_of_nodes x out_channels).
    """
    src = src.unsqueeze(-1) if src.dim() == 1 else src
    row, col = edge_index
    pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo

    # Plain tensor inputs keep plain index tensors; otherwise the index
    # tensors are wrapped in Variables so they can be mixed with `src`.
    src_is_tensor = torch.is_tensor(src)

    n, e, m_out = src.size(0), row.size(0), weight.size(2)

    # Weight each edge: combine the gathered neighbor features `src[col]`
    # with the B-spline-selected weights, giving (e x m_out) edge outputs.
    basis, weight_index = spline_basis(degree, pseudo, kernel_size,
                                       is_open_spline)
    output = spline_weighting(src[col], weight, basis, weight_index)

    # Perform the real convolution: sum every edge output into its target
    # node `row`, converting e x m_out edge features to n x m_out features.
    zero = new(src, n, m_out).fill_(0)
    row_expand = row.unsqueeze(-1).expand(e, m_out)
    row_expand = row_expand if src_is_tensor else Variable(row_expand)
    output = zero.scatter_add_(0, row_expand, output)

    # Normalize output by node degree (the 1/|N(i)| factor). `deg` is kept
    # separate from the B-spline `degree` argument to avoid shadowing it.
    # node_degree presumably counts occurrences of each node in `row`;
    # clamping to 1 avoids dividing by zero for nodes without edges.
    index = row if src_is_tensor else Variable(row)
    deg = node_degree(index, n, out=new(src))
    output /= deg.unsqueeze(-1).clamp(min=1)

    # Weight root node separately (if wished).
    if root_weight is not None:
        output += torch.mm(src, root_weight)

    # Add bias (if wished).
    if bias is not None:
        output += bias

    return output