import torch

from .basis import spline_basis
from .weighting import spline_weighting

from .utils.new import new
from .utils.degree import node_degree


def spline_conv(src,
                edge_index,
                pseudo,
                weight,
                kernel_size,
                is_open_spline,
                degree=1,
                root_weight=None,
                bias=None):
    r"""Applies the spline-based convolutional operator :math:`(f \star g)(i) =
    \frac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)}
    f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph.
    Here, :math:`g_l` denotes the kernel function defined over the weighted
    B-spline tensor product basis for a single input feature map :math:`l`.
    If given, a root-node term and a bias are added to the aggregated result.

    Args:
        src (Tensor): Input node features of shape (number_of_nodes x
            in_channels). A 1-dimensional tensor is treated as having a
            single input channel.
        edge_index (LongTensor): Graph edges, given by source and target
            indices, of shape (2 x number_of_edges). Features are gathered
            from the nodes in ``edge_index[1]`` and aggregated into the nodes
            in ``edge_index[0]``.
        pseudo (Tensor): Edge attributes, i.e. pseudo coordinates, of shape
            (number_of_edges x number_of_edge_attributes). A 1-dimensional
            tensor is treated as having a single edge attribute.
        weight (Tensor): Trainable weight parameters of shape (kernel_size x
            in_channels x out_channels)
        kernel_size (LongTensor): Number of trainable weight parameters in each
            edge dimension
        is_open_spline (ByteTensor): Whether to use open or closed B-spline
            bases for each dimension
        degree (int): B-spline basis degree (default: :obj:`1`)
        root_weight (Tensor): Additional shared trainable parameters for each
            feature of the root node of shape (in_channels x out_channels)
            (default: :obj:`None`)
        bias (Tensor): Optional bias of shape (out_channels) (default:
            :obj:`None`)

    Returns:
        Tensor: Output node features of shape (number_of_nodes x
        out_channels).
    """
    # Promote 1-d inputs to 2-d (single input channel / single edge attribute).
    src = src.unsqueeze(-1) if src.dim() == 1 else src
    pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
    row, col = edge_index

    n, e, m_out = src.size(0), row.size(0), weight.size(2)

    # Weight each neighbor's features (gathered via `col`) with the spline
    # kernel evaluated at the edge's pseudo coordinates -> e x m_out.
    basis, weight_index = spline_basis(degree, pseudo, kernel_size,
                                       is_open_spline)
    output = spline_weighting(src[col], weight, basis, weight_index)

    # Perform the real convolution: sum the per-edge messages into their
    # `row` node, converting e x m_out to n x m_out features.
    row_expand = row.unsqueeze(-1).expand(e, m_out)
    output = new(src, n, m_out).fill_(0).scatter_add_(0, row_expand, output)

    # Normalize by node degree. The clamp keeps nodes without incident edges
    # at zero output instead of computing 0 / 0. (Named `deg` so it does not
    # shadow the B-spline `degree` argument.)
    deg = node_degree(row, n, out=new(src))
    output /= deg.unsqueeze(-1).clamp_(min=1)

    # Weight root node separately (if wished).
    if root_weight is not None:
        output += torch.mm(src, root_weight)

    # Add bias (if wished).
    if bias is not None:
        output += bias

    return output