"""WLN"""
# pylint: disable= no-member, arguments-differ, invalid-name
import math

import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import Parameter

__all__ = ['WLN']

class WLNLinear(nn.Module):
    r"""Linear layer for WLN

    Let stddev be

    .. math::
        \min(\frac{1.0}{\sqrt{in\_feats}}, 0.1)

    The weight of the linear layer is initialized from a normal distribution
    with mean 0 and standard deviation ``stddev``.

    Parameters
    ----------
    in_feats : int
        Size for the input.
    out_feats : int
        Size for the output.
    bias : bool
        Whether bias will be added to the output. Default to True.
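
    Examples
    --------
    A minimal usage sketch; the sizes below are arbitrary and only for
    illustration.

    >>> import torch
    >>> layer = WLNLinear(in_feats=16, out_feats=8)
    >>> feats = torch.randn(4, 16)
    >>> layer(feats).shape
    torch.Size([4, 8])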
    """
    def __init__(self, in_feats, out_feats, bias=True):
        super(WLNLinear, self).__init__()

        self.in_feats = in_feats
        self.out_feats = out_feats
        self.weight = Parameter(torch.Tensor(out_feats, in_feats))
        if bias:
            self.bias = Parameter(torch.Tensor(out_feats))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize model parameters."""
        stddev = min(1.0 / math.sqrt(self.in_feats), 0.1)
        nn.init.normal_(self.weight, std=stddev)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, feats):
        """Applies the layer.

        Parameters
        ----------
        feats : float32 tensor of shape (N, *, in_feats)
            N for the number of samples, * for any additional dimensions.

        Returns
        -------
        float32 tensor of shape (N, *, out_feats)
            Result of the layer.
        """
        return F.linear(feats, self.weight, self.bias)

    def extra_repr(self):
        """Return a description of the layer."""
        return 'in_feats={}, out_feats={}, bias={}'.format(
            self.in_feats, self.out_feats, self.bias is not None
        )

class WLN(nn.Module):
    """Weisfeiler-Lehman Network (WLN)

    WLN is introduced in `Predicting Organic Reaction Outcomes with
    Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__.

    This class performs message passing and updates node representations.
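
    In sketch form (reading off the forward pass below, with ad hoc names
    :math:`W_m, W_n, W_a, W_b, W_s` for the linear layers), each round of
    message passing updates the representation :math:`h_v` of node :math:`v`
    from its in-neighbors :math:`u` and edge features :math:`x_{uv}` via

    .. math::
        m_{uv} = \mathrm{ReLU}(W_m [h_u \| x_{uv}]), \qquad
        h_v \leftarrow \mathrm{ReLU}(W_n [h_v \| \sum_{u} m_{uv}])

    and the final output for node :math:`v` is the elementwise product

    .. math::
        (W_s h_v) \odot \sum_{u} (W_a h_u) \odot (W_b x_{uv})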

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_out_feats : int
        Size for the output node representations. Default to 300.
    n_layers : int
        Number of rounds of message passing. Note that the same parameters
        are shared across all rounds. Default to 3.
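
    Examples
    --------
    A minimal usage sketch, assuming a DGL version that provides
    ``dgl.graph``; the graph and feature sizes below are arbitrary and only
    for illustration.

    >>> import dgl
    >>> import torch
    >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    >>> node_feats = torch.randn(3, 10)
    >>> edge_feats = torch.randn(3, 5)
    >>> model = WLN(node_in_feats=10, edge_in_feats=5, node_out_feats=20)
    >>> model(g, node_feats, edge_feats).shape
    torch.Size([3, 20])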
    """
    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_out_feats=300,
                 n_layers=3):
        super(WLN, self).__init__()

        self.n_layers = n_layers
        self.project_node_in_feats = nn.Sequential(
            WLNLinear(node_in_feats, node_out_feats, bias=False),
            nn.ReLU()
        )
        self.project_concatenated_messages = nn.Sequential(
            WLNLinear(edge_in_feats + node_out_feats, node_out_feats),
            nn.ReLU()
        )
        self.get_new_node_feats = nn.Sequential(
            WLNLinear(2 * node_out_feats, node_out_feats),
            nn.ReLU()
        )
        self.project_edge_messages = WLNLinear(edge_in_feats, node_out_feats, bias=False)
        self.project_node_messages = WLNLinear(node_out_feats, node_out_feats, bias=False)
        self.project_self = WLNLinear(node_out_feats, node_out_feats, bias=False)

    def forward(self, g, node_feats, edge_feats):
        """Performs message passing and updates node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges.

        Returns
        -------
        float32 tensor of shape (V, node_out_feats)
            Updated node representations.
        """
        node_feats = self.project_node_in_feats(node_feats)
        for _ in range(self.n_layers):
            g = g.local_var()
            g.ndata['hv'] = node_feats
            # Copy the representation of each source node onto its out-edges
            g.apply_edges(fn.copy_src('hv', 'he_src'))
            concat_edge_feats = torch.cat([g.edata['he_src'], edge_feats], dim=1)
            g.edata['he'] = self.project_concatenated_messages(concat_edge_feats)
            # Sum the edge messages for each node and combine the aggregate
            # with its current representation to get the updated one
            g.update_all(fn.copy_edge('he', 'm'), fn.sum('m', 'hv_new'))
            node_feats = self.get_new_node_feats(
                torch.cat([node_feats, g.ndata['hv_new']], dim=1))

        g = g.local_var()
        g.ndata['hv'] = self.project_node_messages(node_feats)
        g.edata['he'] = self.project_edge_messages(edge_feats)
        # For each node, sum over its in-edges the elementwise products of
        # the projected neighbor representations and projected edge features
        g.update_all(fn.u_mul_e('hv', 'he', 'm'), fn.sum('m', 'h_nbr'))
        h_self = self.project_self(node_feats)  # (V, node_out_feats)
        return g.ndata['h_nbr'] * h_self