"""WLN"""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import Parameter

__all__ = ['WLN']

class WLNLinear(nn.Module):
    r"""Linear layer for WLN

    Let stddev be

    .. math::
        \min(\frac{1.0}{\sqrt{\text{in\_feats}}}, 0.1)

    The weight of the linear layer is initialized from a normal distribution
    with mean 0 and standard deviation stddev.

    Parameters
    ----------
    in_feats : int
        Size for the input features.
    out_feats : int
        Size for the output features.
    bias : bool
        Whether bias will be added to the output. Default to True.
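
    Example
    -------
    A minimal usage sketch; the feature sizes here are illustrative.

    >>> import torch
    >>> layer = WLNLinear(in_feats=16, out_feats=32)  # hypothetical sizes
    >>> layer(torch.randn(4, 16)).shape
    torch.Size([4, 32])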
    """
    def __init__(self, in_feats, out_feats, bias=True):
        super(WLNLinear, self).__init__()

        self.in_feats = in_feats
        self.out_feats = out_feats
        self.weight = Parameter(torch.Tensor(out_feats, in_feats))
        if bias:
            self.bias = Parameter(torch.Tensor(out_feats))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize model parameters."""
        stddev = min(1.0 / math.sqrt(self.in_feats), 0.1)
        nn.init.normal_(self.weight, std=stddev)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)

    def forward(self, feats):
        """Applies the layer.

        Parameters
        ----------
        feats : float32 tensor of shape (N, *, in_feats)
            Input features. N for the number of samples, * for any
            additional dimensions.

        Returns
        -------
        float32 tensor of shape (N, *, out_feats)
            Result of the layer.
        """
        return F.linear(feats, self.weight, self.bias)

    def extra_repr(self):
        """Return a description of the layer."""
        return 'in_feats={}, out_feats={}, bias={}'.format(
            self.in_feats, self.out_feats, self.bias is not None
        )

class WLN(nn.Module):
    """Weisfeiler-Lehman Network (WLN)

    WLN is introduced in `Predicting Organic Reaction Outcomes with
    Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__.

    This class performs message passing and updates node representations.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_out_feats : int
        Size for the output node representations. Default to 300.
    n_layers : int
        Number of rounds of message passing. Note that the same parameters
        are shared across all n_layers rounds. Default to 3.
    project_in_feats : bool
        Whether to project input node features. If this is False, we expect node_in_feats
        to be the same as node_out_feats. Default to True.
    set_comparison : bool
        Whether to perform final node representation update mimicking
        set comparison. Default to True.
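
    Example
    -------
    A minimal usage sketch; the toy graph and feature sizes here are
    illustrative.

    >>> import dgl
    >>> import torch
    >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # hypothetical 3-node graph
    >>> node_feats = torch.randn(g.num_nodes(), 16)
    >>> edge_feats = torch.randn(g.num_edges(), 5)
    >>> model = WLN(node_in_feats=16, edge_in_feats=5, node_out_feats=32)
    >>> model(g, node_feats, edge_feats).shape
    torch.Size([3, 32])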
    """
    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_out_feats=300,
                 n_layers=3,
                 project_in_feats=True,
                 set_comparison=True):
        super(WLN, self).__init__()

        self.n_layers = n_layers
        self.project_in_feats = project_in_feats
        if project_in_feats:
            self.project_node_in_feats = nn.Sequential(
                WLNLinear(node_in_feats, node_out_feats, bias=False),
                nn.ReLU()
            )
        else:
            assert node_in_feats == node_out_feats, \
                'Expect input node features to have the same size as that of output ' \
                'node features, got {:d} and {:d}'.format(node_in_feats, node_out_feats)

        self.project_concatenated_messages = nn.Sequential(
            WLNLinear(edge_in_feats + node_out_feats, node_out_feats),
            nn.ReLU()
        )
        self.get_new_node_feats = nn.Sequential(
            WLNLinear(2 * node_out_feats, node_out_feats),
            nn.ReLU()
        )
        self.set_comparison = set_comparison
        if set_comparison:
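            # Linear projections used only for the final set-comparison
            # style update of node representations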
            self.project_edge_messages = WLNLinear(edge_in_feats, node_out_feats, bias=False)
            self.project_node_messages = WLNLinear(node_out_feats, node_out_feats, bias=False)
            self.project_self = WLNLinear(node_out_feats, node_out_feats, bias=False)

    def forward(self, g, node_feats, edge_feats):
        """Performs message passing and updates node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges.

        Returns
        -------
        float32 tensor of shape (V, node_out_feats)
            Updated node representations.
        """
        if self.project_in_feats:
            node_feats = self.project_node_in_feats(node_feats)
        for _ in range(self.n_layers):
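            # One round of message passing: copy source node features onto
            # their outgoing edges, combine them with the edge features, and
            # sum the resulting messages into each destination node.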
            g = g.local_var()
            g.ndata['hv'] = node_feats
            g.apply_edges(fn.copy_u('hv', 'he_src'))
            concat_edge_feats = torch.cat([g.edata['he_src'], edge_feats], dim=1)
            g.edata['he'] = self.project_concatenated_messages(concat_edge_feats)
            g.update_all(fn.copy_e('he', 'm'), fn.sum('m', 'hv_new'))
            node_feats = self.get_new_node_feats(
                torch.cat([node_feats, g.ndata['hv_new']], dim=1))

        if not self.set_comparison:
            return node_feats
        else:
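            # Set-comparison update: sum the products of projected neighbor
            # and edge features over incoming edges, then gate the result
            # with the projected self features.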
            g = g.local_var()
            g.ndata['hv'] = self.project_node_messages(node_feats)
            g.edata['he'] = self.project_edge_messages(edge_feats)
            g.update_all(fn.u_mul_e('hv', 'he', 'm'), fn.sum('m', 'h_nbr'))
            h_self = self.project_self(node_feats)  # (V, node_out_feats)
            return g.ndata['h_nbr'] * h_self