"megatron/vscode:/vscode.git/clone" did not exist on "bb618c02a2ae274e1a45a95a4dfb5aede3d3f93b"
conv.py 9.06 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
2
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import SumPooling
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder


### GIN convolution along the graph structure
class GINConv(nn.Module):
    def __init__(self, emb_dim):
        """
        emb_dim (int): node embedding dimensionality
        """

        super(GINConv, self).__init__()

        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        )
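
        # learnable epsilon weighting the centre node in the GIN aggregation,
        # initialised to 0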
        self.eps = nn.Parameter(torch.Tensor([0]))

        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, g, x, edge_attr):
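        # GIN-style update as implemented below: for every node v,
        #   out_v = MLP((1 + eps) * x_v + sum_{u in N(v)} ReLU(x_u + e_uv)),
        # where e_uv is the bond embedding of edge (u, v)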
        with g.local_scope():
            edge_embedding = self.bond_encoder(edge_attr)
            g.ndata["x"] = x
            g.apply_edges(fn.copy_u("x", "m"))
            g.edata["m"] = F.relu(g.edata["m"] + edge_embedding)
            g.update_all(fn.copy_e("m", "m"), fn.sum("m", "new_x"))
            out = self.mlp((1 + self.eps) * x + g.ndata["new_x"])

            return out


### GCN convolution along the graph structure
class GCNConv(nn.Module):
    def __init__(self, emb_dim):
        """
        emb_dim (int): node embedding dimensionality
        """

        super(GCNConv, self).__init__()

        self.linear = nn.Linear(emb_dim, emb_dim)
        self.root_emb = nn.Embedding(1, emb_dim)
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, g, x, edge_attr):
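        # GCN-style update as implemented below: node features are first
        # projected by self.linear, each edge (u, v) is weighted by
        # 1 / sqrt((d_u + 1) * (d_v + 1)), and the self-connection is modelled
        # with the learned root embedding scaled by 1 / (d_v + 1)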
        with g.local_scope():
            x = self.linear(x)
            edge_embedding = self.bond_encoder(edge_attr)

            # Molecular graphs are undirected
            # g.out_degrees() is the same as g.in_degrees()
            degs = (g.out_degrees().float() + 1).to(x.device)
            norm = torch.pow(degs, -0.5).unsqueeze(-1)  # (N, 1)
            g.ndata["norm"] = norm
            g.apply_edges(fn.u_mul_v("norm", "norm", "norm"))

            g.ndata["x"] = x
            g.apply_edges(fn.copy_u("x", "m"))
            g.edata["m"] = g.edata["norm"] * F.relu(
                g.edata["m"] + edge_embedding
            )
            g.update_all(fn.copy_e("m", "m"), fn.sum("m", "new_x"))
            out = g.ndata["new_x"] + F.relu(
                x + self.root_emb.weight
            ) * 1.0 / degs.view(-1, 1)

            return out


### GNN to generate node embedding
class GNN_node(nn.Module):
    """
    Output:
        node representations
    """

    def __init__(
        self,
        num_layers,
        emb_dim,
        drop_ratio=0.5,
        JK="last",
        residual=False,
        gnn_type="gin",
    ):
        """
        num_layers (int): number of GNN message passing layers
        emb_dim (int): node embedding dimensionality
        """

        super(GNN_node, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ### List of GNNs
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        for layer in range(num_layers):
            if gnn_type == "gin":
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == "gcn":
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError("Undefined GNN type called {}".format(gnn_type))

            self.batch_norms.append(nn.BatchNorm1d(emb_dim))

    def forward(self, g, x, edge_attr):
        ### computing input node embedding
        h_list = [self.atom_encoder(x)]
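        # layer-wise propagation: conv -> batch norm -> ReLU (skipped on the
        # last layer) -> dropout, with optional residual connections; all
        # intermediate embeddings are kept in h_list for the JK combination below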
        for layer in range(self.num_layers):
            h = self.convs[layer](g, h_list[layer], edge_attr)
            h = self.batch_norms[layer](h)

            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(
                    F.relu(h), self.drop_ratio, training=self.training
                )

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        ### Different implementations of JK-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layers):
                node_representation += h_list[layer]

        return node_representation


### Virtual GNN to generate node embedding
class GNN_node_Virtualnode(nn.Module):
    """
    Output:
        node representations
    """

    def __init__(
        self,
        num_layers,
        emb_dim,
        drop_ratio=0.5,
        JK="last",
        residual=False,
        gnn_type="gin",
    ):
        """
        num_layers (int): number of GNN message passing layers
        emb_dim (int): node embedding dimensionality
        """

        super(GNN_node_Virtualnode, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ### set the initial virtual node embedding to 0.
        self.virtualnode_embedding = nn.Embedding(1, emb_dim)
        nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

        ### List of GNNs
        self.convs = nn.ModuleList()
        ### batch norms applied to node embeddings
        self.batch_norms = nn.ModuleList()

        ### List of MLPs to transform virtual node at every layer
        self.mlp_virtualnode_list = nn.ModuleList()

        for layer in range(num_layers):
            if gnn_type == "gin":
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == "gcn":
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError("Undefined GNN type called {}".format(gnn_type))

            self.batch_norms.append(nn.BatchNorm1d(emb_dim))

        for layer in range(num_layers - 1):
            self.mlp_virtualnode_list.append(
                nn.Sequential(
                    nn.Linear(emb_dim, emb_dim),
                    nn.BatchNorm1d(emb_dim),
                    nn.ReLU(),
                    nn.Linear(emb_dim, emb_dim),
                    nn.BatchNorm1d(emb_dim),
                    nn.ReLU(),
                )
            )
        self.pool = SumPooling()

    def forward(self, g, x, edge_attr):
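        # one virtual node per graph in the batch: it starts from the learned
        # zero embedding, broadcasts its state to the graph's nodes before each
        # conv layer, and is refreshed from the pooled node states after every
        # layer except the last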
        ### virtual node embeddings for graphs
        virtualnode_embedding = self.virtualnode_embedding(
            torch.zeros(g.batch_size).to(x.dtype).to(x.device)
        )

        h_list = [self.atom_encoder(x)]
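        # batch_id[i] is the index of the graph that node i belongs to; it is
        # used to pick the matching virtual node embedding for every node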
        batch_id = dgl.broadcast_nodes(
            g, torch.arange(g.batch_size).to(x.device)
        )
        for layer in range(self.num_layers):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch_id]

            ### Message passing among graph nodes
            h = self.convs[layer](g, h_list[layer], edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(
                    F.relu(h), self.drop_ratio, training=self.training
                )

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

            ### update the virtual nodes
            if layer < self.num_layers - 1:
                ### add message from graph nodes to virtual nodes
                virtualnode_embedding_temp = (
                    self.pool(g, h_list[layer]) + virtualnode_embedding
                )
                ### transform virtual nodes using MLP
                virtualnode_embedding_temp = self.mlp_virtualnode_list[layer](
                    virtualnode_embedding_temp
                )

                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + F.dropout(
                        virtualnode_embedding_temp,
                        self.drop_ratio,
                        training=self.training,
                    )
                else:
                    virtualnode_embedding = F.dropout(
                        virtualnode_embedding_temp,
                        self.drop_ratio,
                        training=self.training,
                    )

        ### Different implementations of JK-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layers):
                node_representation += h_list[layer]

        return node_representation
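

### Minimal usage sketch: builds a tiny batched graph with dummy OGB-style
### molecular features and runs GNN_node once. This assumes the ogbg-mol*
### convention of 9 integer atom features per node and 3 integer bond features
### per edge; adjust the feature widths if your OGB version differs.
if __name__ == "__main__":
    # two toy "molecules", made bidirectional because the convs above assume
    # undirected graphs
    g1 = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))
    g2 = dgl.graph(([0, 1], [1, 0]))
    g = dgl.batch([g1, g2])

    x = torch.zeros(g.num_nodes(), 9, dtype=torch.long)  # dummy atom features
    edge_attr = torch.zeros(g.num_edges(), 3, dtype=torch.long)  # dummy bond features

    model = GNN_node(num_layers=2, emb_dim=16, gnn_type="gin")
    model.eval()  # keep BatchNorm in eval mode for this tiny batch
    with torch.no_grad():
        out = model(g, x, edge_attr)
    print(out.shape)  # expected: torch.Size([5, 16]) == (num_nodes, emb_dim)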