import torch
import torch.nn as nn
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

import dgl
import dgl.function as fn
from dgl.nn.pytorch import SumPooling


### GIN convolution along the graph structure
class GINConv(nn.Module):
    def __init__(self, emb_dim):
        """
        emb_dim (int): node embedding dimensionality
        """

        super(GINConv, self).__init__()

        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        )
        # learnable epsilon weighting the center node, as in GIN
        self.eps = nn.Parameter(torch.Tensor([0]))

        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, g, x, edge_attr):
        with g.local_scope():
            edge_embedding = self.bond_encoder(edge_attr)
            g.ndata["x"] = x
            g.apply_edges(fn.copy_u("x", "m"))
            g.edata["m"] = F.relu(g.edata["m"] + edge_embedding)
            g.update_all(fn.copy_e("m", "m"), fn.sum("m", "new_x"))
            out = self.mlp((1 + self.eps) * x + g.ndata["new_x"])

            return out

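# For reference, the GINConv layer above implements the edge-aware GIN update
#     h_v' = MLP((1 + eps) * h_v + sum_{u in N(v)} ReLU(h_u + e_uv)),
# where e_uv is the bond embedding of edge (u, v).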

### GCN convolution along the graph structure
class GCNConv(nn.Module):
    def __init__(self, emb_dim):
        """
        emb_dim (int): node embedding dimensionality
        """

        super(GCNConv, self).__init__()

        self.linear = nn.Linear(emb_dim, emb_dim)
        # learned embedding used as the feature of the added self-connection
        self.root_emb = nn.Embedding(1, emb_dim)
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, g, x, edge_attr):
        with g.local_scope():
            x = self.linear(x)
            edge_embedding = self.bond_encoder(edge_attr)

            # Molecular graphs are undirected
            # g.out_degrees() is the same as g.in_degrees()
            degs = (g.out_degrees().float() + 1).to(x.device)
            norm = torch.pow(degs, -0.5).unsqueeze(-1)  # (N, 1)
            g.ndata["norm"] = norm
            # symmetric normalization 1 / sqrt(d_u * d_v) on each edge
            g.apply_edges(fn.u_mul_v("norm", "norm", "norm"))

            # aggregate normalized messages ReLU(W h_u + e_uv) from neighbors
            g.ndata["x"] = x
            g.apply_edges(fn.copy_u("x", "m"))
            g.edata["m"] = g.edata["norm"] * F.relu(
                g.edata["m"] + edge_embedding
            )
            g.update_all(fn.copy_e("m", "m"), fn.sum("m", "new_x"))
            # add the degree-normalized self-connection ("root") term
            out = g.ndata["new_x"] + F.relu(
                x + self.root_emb.weight
            ) * 1.0 / degs.view(-1, 1)

            return out

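# For reference, the GCNConv layer above implements a degree-normalized update
# with an added self-connection term (d_v = degree of node v plus one):
#     h_v' = sum_{u in N(v)} ReLU(W h_u + e_uv) / sqrt(d_u * d_v)
#            + ReLU(W h_v + root_emb) / d_v,
# where e_uv is the bond embedding of edge (u, v).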

### GNN to generate node embedding
class GNN_node(nn.Module):
    """
    Output:
        node representations
    """

    def __init__(
        self,
        num_layers,
        emb_dim,
        drop_ratio=0.5,
        JK="last",
        residual=False,
        gnn_type="gin",
    ):
        """
        num_layers (int): number of GNN message passing layers
        emb_dim (int): node embedding dimensionality
        """

        super(GNN_node, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ### List of GNNs
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        for layer in range(num_layers):
            if gnn_type == "gin":
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == "gcn":
                self.convs.append(GCNConv(emb_dim))
            else:
                ValueError("Undefined GNN type called {}".format(gnn_type))

            self.batch_norms.append(nn.BatchNorm1d(emb_dim))

    def forward(self, g, x, edge_attr):
        ### computing input node embedding
        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layers):

            h = self.convs[layer](g, h_list[layer], edge_attr)
            h = self.batch_norms[layer](h)

            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(
                    F.relu(h), self.drop_ratio, training=self.training
                )

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        ### Different implementations of JK-concat (jumping knowledge)
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            # sum over the input embedding and every layer's output
            for layer in range(self.num_layers + 1):
                node_representation += h_list[layer]

        return node_representation


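# A minimal sketch (added for illustration, not part of the original file) of
# how GNN_node is typically combined with a graph-level readout for property
# prediction. The class name GNNGraphPred and the num_tasks argument are
# assumptions made here, not definitions taken from this repository.
class GNNGraphPred(nn.Module):
    def __init__(self, num_layers, emb_dim, num_tasks):
        super(GNNGraphPred, self).__init__()
        # node-embedding GNN defined above
        self.gnn_node = GNN_node(num_layers, emb_dim)
        # sum readout over the nodes of each graph in the batch
        self.pool = SumPooling()
        # linear head mapping each graph embedding to num_tasks outputs
        self.graph_pred_linear = nn.Linear(emb_dim, num_tasks)

    def forward(self, g, x, edge_attr):
        h_node = self.gnn_node(g, x, edge_attr)
        h_graph = self.pool(g, h_node)
        return self.graph_pred_linear(h_graph)
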
### Virtual GNN to generate node embedding
class GNN_node_Virtualnode(nn.Module):
    """
    Output:
        node representations
    """

    def __init__(
        self,
        num_layers,
        emb_dim,
        drop_ratio=0.5,
        JK="last",
        residual=False,
        gnn_type="gin",
    ):
        """
        num_layers (int): number of GNN message passing layers
        emb_dim (int): node embedding dimensionality
        """

        super(GNN_node_Virtualnode, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ### set the initial virtual node embedding to 0.
        self.virtualnode_embedding = nn.Embedding(1, emb_dim)
        nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

        ### List of GNNs
        self.convs = nn.ModuleList()
        ### batch norms applied to node embeddings
        self.batch_norms = nn.ModuleList()

        ### List of MLPs to transform virtual node at every layer
        self.mlp_virtualnode_list = nn.ModuleList()

        for layer in range(num_layers):
            if gnn_type == "gin":
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == "gcn":
                self.convs.append(GCNConv(emb_dim))
            else:
                ValueError("Undefined GNN type called {}".format(gnn_type))

            self.batch_norms.append(nn.BatchNorm1d(emb_dim))

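        # The virtual node is updated after every layer except the last, so
        # only num_layers - 1 virtual-node MLPs are needed.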
        for layer in range(num_layers - 1):
            self.mlp_virtualnode_list.append(
                nn.Sequential(
                    nn.Linear(emb_dim, emb_dim),
                    nn.BatchNorm1d(emb_dim),
                    nn.ReLU(),
                    nn.Linear(emb_dim, emb_dim),
                    nn.BatchNorm1d(emb_dim),
                    nn.ReLU(),
                )
            )
        self.pool = SumPooling()

    def forward(self, g, x, edge_attr):
        ### virtual node embeddings for graphs
        virtualnode_embedding = self.virtualnode_embedding(
            torch.zeros(g.batch_size).to(x.dtype).to(x.device)
        )

        h_list = [self.atom_encoder(x)]
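        # batch_id[i] is the index (within the batch) of the graph that node i
        # belongs to; it is used to gather each node's virtual-node embedding.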
        batch_id = dgl.broadcast_nodes(
            g, torch.arange(g.batch_size).to(x.device)
        )
        for layer in range(self.num_layers):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch_id]

            ### Message passing among graph nodes
            h = self.convs[layer](g, h_list[layer], edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(
                    F.relu(h), self.drop_ratio, training=self.training
                )

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

            ### update the virtual nodes
            if layer < self.num_layers - 1:
                ### add message from graph nodes to virtual nodes
                virtualnode_embedding_temp = (
                    self.pool(g, h_list[layer]) + virtualnode_embedding
                )
                ### transform virtual nodes using MLP
                virtualnode_embedding_temp = self.mlp_virtualnode_list[layer](
                    virtualnode_embedding_temp
                )

                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + F.dropout(
                        virtualnode_embedding_temp,
                        self.drop_ratio,
                        training=self.training,
                    )
                else:
                    virtualnode_embedding = F.dropout(
                        virtualnode_embedding_temp,
                        self.drop_ratio,
                        training=self.training,
                    )

        ### Different implementations of JK-concat (jumping knowledge)
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            # sum over the input embedding and every layer's output
            for layer in range(self.num_layers + 1):
                node_representation += h_list[layer]

        return node_representation
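

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of the
    # original file). It assumes OGB-style molecular graphs: 9 categorical
    # atom features per node and 3 categorical bond features per edge, with
    # index 0 being a valid category for every feature.
    g = dgl.batch(
        [
            dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1])),
            dgl.graph(([0, 1], [1, 0])),
        ]
    )
    x = torch.zeros(g.num_nodes(), 9, dtype=torch.long)
    edge_attr = torch.zeros(g.num_edges(), 3, dtype=torch.long)

    for model in (
        GNN_node(num_layers=2, emb_dim=16, gnn_type="gin"),
        GNN_node_Virtualnode(num_layers=2, emb_dim=16, gnn_type="gcn"),
    ):
        model.eval()  # use BatchNorm running stats for this tiny batch
        out = model(g, x, edge_attr)
        print(type(model).__name__, tuple(out.shape))  # (num_nodes, 16)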