conv.py 2.92 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
linting  
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
4
from .basis import SplineBasis
from .weighting import SplineWeighting
rusty1s's avatar
linting  
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
from .utils.degree import degree as node_degree
rusty1s's avatar
rusty1s committed
7
8


rusty1s's avatar
rusty1s committed
9
class SplineConv(object):
    r"""Applies the spline-based convolution operator :math:`(f \star g)(i) =
    \frac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)}
    f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph.

    The kernel function :math:`g_l` is defined over the weighted B-spline
    tensor product basis for a single input feature map :math:`l`.

    Args:
        src (:class:`Tensor`): Input node features of shape
            (number_of_nodes x in_channels). A 1-dimensional tensor is
            treated as having a single input channel.
        edge_index (:class:`LongTensor`): Graph edges, given by source and
            target indices, of shape (2 x number_of_edges). Features are
            gathered from the nodes in ``edge_index[1]`` and aggregated into
            the nodes in ``edge_index[0]``.
        pseudo (:class:`Tensor`): Edge attributes, i.e. pseudo coordinates,
            of shape (number_of_edges x number_of_edge_attributes) in the
            fixed interval [0, 1]. A 1-dimensional tensor is treated as
            having a single edge attribute.
        weight (:class:`Tensor`): Trainable weight parameters of shape
            (kernel_size x in_channels x out_channels).
        kernel_size (:class:`LongTensor`): Number of trainable weight
            parameters in each edge dimension.
        is_open_spline (:class:`ByteTensor`): Whether to use open or closed
            B-spline bases for each dimension.
        degree (int, optional): B-spline basis degree. (default: :obj:`1`)
        root_weight (:class:`Tensor`, optional): Additional shared trainable
            parameters for each feature of the root node of shape
            (in_channels x out_channels). (default: :obj:`None`)
        bias (:class:`Tensor`, optional): Optional bias of shape
            (out_channels). (default: :obj:`None`)

    Returns a tensor of shape (number_of_nodes x out_channels).

    :rtype: :class:`Tensor`
    """

    @staticmethod
    def apply(src,
              edge_index,
              pseudo,
              weight,
              kernel_size,
              is_open_spline,
              degree=1,
              root_weight=None,
              bias=None):
        """Run the spline convolution.

        See the class docstring for the meaning and expected shapes of each
        argument.
        """
        # Promote 1-D inputs to have a trailing channel / attribute dimension.
        src = src.unsqueeze(-1) if src.dim() == 1 else src
        pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo

        row, col = edge_index
        n, m_out = src.size(0), weight.size(2)

        # Compute one message per edge: the features of each edge's `col`
        # node, weighted by the B-spline kernel evaluated at that edge's
        # pseudo-coordinates. The results of SplineBasis.apply are forwarded
        # positionally to SplineWeighting.apply.
        data = SplineBasis.apply(pseudo, kernel_size, is_open_spline, degree)
        output = SplineWeighting.apply(src[col], weight, *data)

        # Convert e x m_out to n x m_out features by summing each edge's
        # message into its `row` node.
        row_expand = row.unsqueeze(-1).expand_as(output)
        output = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, output)

        # Normalize output by node degree (computed from `row`). Clamping to
        # 1 leaves nodes with no edges at zero instead of dividing by zero.
        deg = node_degree(row, n, out=src.new_empty(()))
        output /= deg.unsqueeze(-1).clamp(min=1)

        # Weight root node separately (if wished).
        if root_weight is not None:
            output += torch.mm(src, root_weight)

        # Add bias (if wished).
        if bias is not None:
            output += bias

        return output