import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import GRU, BatchNorm1d, Linear, ModuleList, ReLU, Sequential
from utils import global_global_loss_, local_global_loss_

from dgl.nn import GINConv, NNConv, Set2Set
from dgl.nn.pytorch.glob import SumPooling

""" Feedforward neural network"""


class FeedforwardNetwork(nn.Module):
    """
    3-layer feed-forward neural network with a jumping connection.
    Parameters
    -----------
    in_dim: int
        Input feature size.
    hid_dim: int
        Hidden feature size.

    Functions
    -----------
    forward(feat):
        feat: Tensor
            [N * D], input features
    """

    def __init__(self, in_dim, hid_dim):
        super(FeedforwardNetwork, self).__init__()

        self.block = Sequential(
            Linear(in_dim, hid_dim),
            ReLU(),
            Linear(hid_dim, hid_dim),
            ReLU(),
            Linear(hid_dim, hid_dim),
            ReLU(),
        )

        self.jump_con = Linear(in_dim, hid_dim)

    def forward(self, feat):
        block_out = self.block(feat)
        jump_out = self.jump_con(feat)

        out = block_out + jump_out

        return out


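# A minimal usage sketch for FeedforwardNetwork (illustrative only; all
# dimensions here are hypothetical):
#
#   net = FeedforwardNetwork(in_dim=32, hid_dim=64)
#   feat = th.randn(10, 32)   # N = 10 samples, D = 32 input features
#   out = net(feat)           # (10, 64): 3-layer MLP output plus jump connection
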
""" Unsupervised Setting """


class GINEncoder(nn.Module):
    """
    Encoder based on dgl.nn.GINConv and dgl.nn.SumPooling.
    Parameters
    -----------
    in_dim: int
        Input feature size.
    hid_dim: int
        Hidden feature size.
    n_layer: int
        Number of GIN layers.

    Functions
    -----------
    forward(graph, feat):
        graph: DGLGraph
        feat: Tensor
            [N * D], node features
    """

    def __init__(self, in_dim, hid_dim, n_layer):
        super(GINEncoder, self).__init__()

        self.n_layer = n_layer

        self.convs = ModuleList()
        self.bns = ModuleList()

        for i in range(n_layer):
            if i == 0:
                n_in = in_dim
            else:
                n_in = hid_dim
            n_out = hid_dim
            # 2-layer MLP used as the GINConv apply function
            block = Sequential(
                Linear(n_in, n_out), ReLU(), Linear(n_out, n_out)
            )

            conv = GINConv(apply_func=block, aggregator_type="sum")
            bn = BatchNorm1d(hid_dim)

            self.convs.append(conv)
            self.bns.append(bn)

        # sum pooling
        self.pool = SumPooling()

    def forward(self, graph, feat):

        xs = []
        x = feat
        for i in range(self.n_layer):
            x = F.relu(self.convs[i](graph, x))
            x = self.bns[i](x)
            xs.append(x)

        local_emb = th.cat(xs, 1)  # patch-level embedding
        global_emb = self.pool(graph, local_emb)  # graph-level embedding

        return global_emb, local_emb

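# A minimal usage sketch for GINEncoder on a random graph (illustrative
# only; dgl.rand_graph is used just to build a toy DGLGraph):
#
#   import dgl
#   g = dgl.rand_graph(10, 20)   # 10 nodes, 20 edges
#   feat = th.randn(10, 32)
#   enc = GINEncoder(in_dim=32, hid_dim=64, n_layer=3)
#   global_emb, local_emb = enc(g, feat)
#   # local_emb:  (10, 192) -- per-layer node embeddings concatenated (64 * 3)
#   # global_emb: (1, 192)  -- sum-pooled over all nodes of the graph
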

class InfoGraph(nn.Module):
    r"""
    InfoGraph model for the unsupervised setting.

    Parameters
    -----------
    in_dim: int
        Input feature size.
    hid_dim: int
        Hidden feature size.
    n_layer: int
        Number of GNN encoder layers.

    Functions
    -----------
    forward(graph):
        graph: DGLGraph

    """

    def __init__(self, in_dim, hid_dim, n_layer):
        super(InfoGraph, self).__init__()

        self.in_dim = in_dim
        self.hid_dim = hid_dim

        self.n_layer = n_layer
        embedding_dim = hid_dim * n_layer

        self.encoder = GINEncoder(in_dim, hid_dim, n_layer)

        self.local_d = FeedforwardNetwork(
            embedding_dim, embedding_dim
        )  # local discriminator (node-level)
        self.global_d = FeedforwardNetwork(
            embedding_dim, embedding_dim
        )  # global discriminator (graph-level)

    def get_embedding(self, graph, feat):
        # Return graph-level embeddings for evaluation, without tracking gradients.

        with th.no_grad():
            global_emb, _ = self.encoder(graph, feat)

        return global_emb

    def forward(self, graph, feat, graph_id):

        global_emb, local_emb = self.encoder(graph, feat)

        global_h = self.global_d(global_emb)  # global hidden representation
        local_h = self.local_d(local_emb)  # local hidden representation

        loss = local_global_loss_(local_h, global_h, graph_id)

        return loss

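# A minimal training-step sketch, reusing the toy `g` and `feat` from the
# GINEncoder sketch above, and assuming that `graph_id` assigns each node to
# the index of the graph it belongs to, as expected by utils.local_global_loss_:
#
#   model = InfoGraph(in_dim=32, hid_dim=64, n_layer=3)
#   graph_id = th.zeros(10, dtype=th.long)   # all 10 nodes belong to graph 0
#   loss = model(g, feat, graph_id)          # local-global infomax loss
#   emb = model.get_embedding(g, feat)       # (1, 192), for downstream evaluation
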

""" Semisupervised Setting """


class NNConvEncoder(nn.Module):
    """
    Encoder based on dgl.nn.NNConv, GRU, and dgl.nn.Set2Set pooling.
    Parameters
    -----------
    in_dim: int
        Input feature size.
    hid_dim: int
        Hidden feature size.

    Functions
    -----------
    forward(graph, nfeat, efeat):
        graph: DGLGraph
        nfeat: Tensor
            [N * D1], node features
        efeat: Tensor
            [E * D2], edge features
    """

    def __init__(self, in_dim, hid_dim):
        super(NNConvEncoder, self).__init__()

        self.lin0 = Linear(in_dim, hid_dim)

        # MLP that maps each edge feature to a (hid_dim x hid_dim) weight matrix for NNConv
        # NOTE: the edge feature size is hard-coded to 5 here
        block = Sequential(
            Linear(5, 128), ReLU(), Linear(128, hid_dim * hid_dim)
        )

        self.conv = NNConv(
            hid_dim,
            hid_dim,
            edge_func=block,
            aggregator_type="mean",
            residual=False,
        )

        self.gru = GRU(hid_dim, hid_dim)

        # set2set pooling
        self.set2set = Set2Set(hid_dim, n_iters=3, n_layers=1)

    def forward(self, graph, nfeat, efeat):

        out = F.relu(self.lin0(nfeat))
        h = out.unsqueeze(0)

        feat_map = []

        # 3 rounds of NNConv message passing with a shared GRU state update
        for i in range(3):
            m = F.relu(self.conv(graph, out, efeat))
            out, h = self.gru(m.unsqueeze(0), h)
            out = out.squeeze(0)
            feat_map.append(out)

        out = self.set2set(graph, out)

        # out: global embedding, feat_map[-1]: local embedding
        return out, feat_map[-1]

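# A minimal usage sketch for NNConvEncoder, reusing the toy graph `g` from
# above (illustrative only; the edge features must be 5-dimensional to match
# the hard-coded edge network):
#
#   enc = NNConvEncoder(in_dim=32, hid_dim=64)
#   nfeat = th.randn(10, 32)
#   efeat = th.randn(20, 5)   # one 5-dim feature per edge of the toy graph
#   global_emb, local_emb = enc(g, nfeat, efeat)
#   # global_emb: (1, 128) -- Set2Set doubles the feature size (2 * hid_dim)
#   # local_emb:  (10, 64) -- node states after the last convolution
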

class InfoGraphS(nn.Module):
    """
    InfoGraph* model for the semi-supervised setting.
    Parameters
    -----------
    in_dim: int
        Input feature size.
    hid_dim: int
        Hidden feature size.

    Functions
    -----------
    forward(graph):
        graph: DGLGraph

    unsupforward(graph):
        graph: DGLGraph

    """

    def __init__(self, in_dim, hid_dim):
        super(InfoGraphS, self).__init__()

        self.sup_encoder = NNConvEncoder(in_dim, hid_dim)
        self.unsup_encoder = NNConvEncoder(in_dim, hid_dim)

        self.fc1 = Linear(2 * hid_dim, hid_dim)
        self.fc2 = Linear(hid_dim, 1)

        # unsupervised local discriminator and global discriminator for local-global infomax
        self.unsup_local_d = FeedforwardNetwork(hid_dim, hid_dim)
        self.unsup_global_d = FeedforwardNetwork(2 * hid_dim, hid_dim)

        # supervised global discriminator and unsupervised global discriminator for global-global infomax
        self.sup_d = FeedforwardNetwork(2 * hid_dim, hid_dim)
        self.unsup_d = FeedforwardNetwork(2 * hid_dim, hid_dim)

    def forward(self, graph, nfeat, efeat):
        sup_global_emb, _ = self.sup_encoder(graph, nfeat, efeat)

        sup_global_pred = self.fc2(F.relu(self.fc1(sup_global_emb)))
        sup_global_pred = sup_global_pred.view(-1)

        return sup_global_pred

    def unsup_forward(self, graph, nfeat, efeat, graph_id):

        sup_global_emb, _ = self.sup_encoder(graph, nfeat, efeat)
        unsup_global_emb, unsup_local_emb = self.unsup_encoder(
            graph, nfeat, efeat
        )

        g_enc = self.unsup_global_d(unsup_global_emb)
        l_enc = self.unsup_local_d(unsup_local_emb)

        sup_g_enc = self.sup_d(sup_global_emb)
        unsup_g_enc = self.unsup_d(unsup_global_emb)

        # Calculate loss
        unsup_loss = local_global_loss_(l_enc, g_enc, graph_id)
        con_loss = global_global_loss_(sup_g_enc, unsup_g_enc)

        return unsup_loss, con_loss
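

# A minimal usage sketch for the semi-supervised model, reusing the toy
# tensors from the sketches above (illustrative only):
#
#   model = InfoGraphS(in_dim=32, hid_dim=64)
#   pred = model(g, nfeat, efeat)   # per-graph target prediction, shape (1,)
#   unsup_loss, con_loss = model.unsup_forward(g, nfeat, efeat, graph_id)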