import torch
import torch.nn as nn
from torch.nn import init
from torch.nn.parameter import Parameter

from dgl.nn.pytorch import GraphConv


class MatGRUCell(torch.nn.Module):
    """
    GRU cell for matrix, similar to the official code.
    Please refer to section 3.4 of the paper for the formula.
    """

    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.update = MatGRUGate(in_feats, out_feats, torch.nn.Sigmoid())

        self.reset = MatGRUGate(in_feats, out_feats, torch.nn.Sigmoid())

        self.htilda = MatGRUGate(in_feats, out_feats, torch.nn.Tanh())

    def forward(self, prev_Q, z_topk=None):
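        # prev_Q: the previous weight matrix (the GRU hidden state).
        # z_topk: the top-k node-embedding summary (EvolveGCN-H); when None, the
        # weights evolve on their own, as in EvolveGCN-O.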
        if z_topk is None:
            z_topk = prev_Q

        update = self.update(z_topk, prev_Q)
        reset = self.reset(z_topk, prev_Q)

        h_cap = reset * prev_Q
        h_cap = self.htilda(z_topk, h_cap)

        new_Q = (1 - update) * prev_Q + update * h_cap

        return new_Q


class MatGRUGate(torch.nn.Module):
    """
    GRU gate for matrix, similar to the official code.
    Please refer to section 3.4 of the paper for the formula.
    """

    def __init__(self, rows, cols, activation):
        super().__init__()
        self.activation = activation
        self.W = Parameter(torch.Tensor(rows, rows))
        self.U = Parameter(torch.Tensor(rows, rows))
        self.bias = Parameter(torch.Tensor(rows, cols))
        self.reset_parameters()

    def reset_parameters(self):
        init.xavier_uniform_(self.W)
        init.xavier_uniform_(self.U)
        init.zeros_(self.bias)

    def forward(self, x, hidden):
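        # Shapes: x and hidden are (rows, cols); W and U are (rows, rows);
        # bias is (rows, cols).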
        out = self.activation(
            self.W.matmul(x) + self.U.matmul(hidden) + self.bias
        )

        return out


class TopK(torch.nn.Module):
    """
    Similar to the official `egcn_h.py`. We only consider the node in a timestamp based subgraph,
    so we need to pay attention to `K` should be less than the min node numbers in all subgraph.
    Please refer to section 3.4 of the paper for the formula.
    """

    def __init__(self, feats, k):
        super().__init__()
        self.scorer = Parameter(torch.Tensor(feats, 1))
        self.reset_parameters()

        self.k = k

    def reset_parameters(self):
        init.xavier_uniform_(self.scorer)

    def forward(self, node_embs):
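        # score each node by its projection onto the learned scoring vector,
        # normalized by the vector's norm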
        scores = node_embs.matmul(self.scorer) / self.scorer.norm().clamp(
            min=1e-6
        )
        _, topk_indices = scores.view(-1).topk(self.k)
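        # weight the selected embeddings by tanh(score) so gradients flow back
        # to the scorer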
        out = node_embs[topk_indices] * torch.tanh(
            scores[topk_indices].view(-1, 1)
        )
        # transpose so the summary has shape (feats, k), matching the weight
        # matrix it will evolve
        return out.t()


class EvolveGCNH(nn.Module):
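    """
    EvolveGCN-H: the GCN weights of each layer are treated as the hidden state of a
    GRU whose input is a top-k summary of the current node embeddings.
    """
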
    def __init__(
        self,
        in_feats=166,
        n_hidden=76,
        num_layers=2,
        n_classes=2,
        classifier_hidden=510,
    ):
        # default parameters follow the official config
        super().__init__()
        self.num_layers = num_layers
        self.pooling_layers = nn.ModuleList()
        self.recurrent_layers = nn.ModuleList()
        self.gnn_convs = nn.ModuleList()
        self.gcn_weights_list = nn.ParameterList()

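        # The TopK `k` equals n_hidden, so the pooled summary X_tilde has shape
        # (in_feats, n_hidden) and matches the GCN weight matrix it evolves.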
        self.pooling_layers.append(TopK(in_feats, n_hidden))
        # similar to EvolveGCNO
        self.recurrent_layers.append(
            MatGRUCell(in_feats=in_feats, out_feats=n_hidden)
        )
        self.gcn_weights_list.append(
            Parameter(torch.Tensor(in_feats, n_hidden))
        )
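        # weight=False: the conv holds no static weight of its own; the evolved
        # W is passed in at call time via forward(..., weight=W)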
        self.gnn_convs.append(
            GraphConv(
                in_feats=in_feats,
                out_feats=n_hidden,
                bias=False,
                activation=nn.RReLU(),
                weight=False,
            )
        )
        for _ in range(num_layers - 1):
            self.pooling_layers.append(TopK(n_hidden, n_hidden))
            self.recurrent_layers.append(
                MatGRUCell(in_feats=n_hidden, out_feats=n_hidden)
            )
            self.gcn_weights_list.append(
                Parameter(torch.Tensor(n_hidden, n_hidden))
            )
            self.gnn_convs.append(
                GraphConv(
                    in_feats=n_hidden,
                    out_feats=n_hidden,
                    bias=False,
                    activation=nn.RReLU(),
                    weight=False,
                )
            )

        self.mlp = nn.Sequential(
            nn.Linear(n_hidden, classifier_hidden),
            nn.ReLU(),
            nn.Linear(classifier_hidden, n_classes),
        )
        self.reset_parameters()

    def reset_parameters(self):
        for gcn_weight in self.gcn_weights_list:
            init.xavier_uniform_(gcn_weight)

    def forward(self, g_list):
        feature_list = []
        for g in g_list:
            feature_list.append(g.ndata["feat"])
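        # For each layer, evolve the weights across snapshots: pool the current
        # features into X_tilde, update W with the GRU, then convolve with the
        # evolved W.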
        for i in range(self.num_layers):
            W = self.gcn_weights_list[i]
            for j, g in enumerate(g_list):
                X_tilde = self.pooling_layers[i](feature_list[j])
                W = self.recurrent_layers[i](W, X_tilde)
                feature_list[j] = self.gnn_convs[i](
                    g, feature_list[j], weight=W
                )
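        # node classification on the last snapshot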
        return self.mlp(feature_list[-1])


class EvolveGCNO(nn.Module):
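    """
    EvolveGCN-O: the GCN weights of each layer evolve through a recurrent cell on
    their own; the weight matrix serves as both hidden state and input (see
    MatGRUCell.forward with z_topk=None).
    """
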
    def __init__(
        self,
        in_feats=166,
        n_hidden=256,
        num_layers=2,
        n_classes=2,
        classifier_hidden=307,
    ):
        # default parameters follow the official config
        super().__init__()
        self.num_layers = num_layers
        self.recurrent_layers = nn.ModuleList()
        self.gnn_convs = nn.ModuleList()
        self.gcn_weights_list = nn.ParameterList()

        # In the paper, EvolveGCN-O uses an LSTM as the recurrent layer, but the official
        # code uses a GRU. Here we follow the official code.
        # See: https://github.com/IBM/EvolveGCN/blob/90869062bbc98d56935e3d92e1d9b1b4c25be593/egcn_o.py#L53
        # PS: I tried using torch.nn.LSTM directly,
        #     like [pyg_temporal](github.com/benedekrozemberczki/pytorch_geometric_temporal/blob/master/torch_geometric_temporal/nn/recurrent/evolvegcno.py),
        #     but the performance was worse than with torch.nn.GRU.
        # PPS: I don't think torch.nn.GRU can match the manually implemented GRU cell in
        #      the official code, so we follow the official code here.
        self.recurrent_layers.append(
            MatGRUCell(in_feats=in_feats, out_feats=n_hidden)
        )
        self.gcn_weights_list.append(
            Parameter(torch.Tensor(in_feats, n_hidden))
        )
        self.gnn_convs.append(
            GraphConv(
                in_feats=in_feats,
                out_feats=n_hidden,
                bias=False,
                activation=nn.RReLU(),
                weight=False,
            )
        )
        for _ in range(num_layers - 1):
            self.recurrent_layers.append(
                MatGRUCell(in_feats=n_hidden, out_feats=n_hidden)
            )
            self.gcn_weights_list.append(
                Parameter(torch.Tensor(n_hidden, n_hidden))
            )
            self.gnn_convs.append(
                GraphConv(
                    in_feats=n_hidden,
                    out_feats=n_hidden,
                    bias=False,
                    activation=nn.RReLU(),
                    weight=False,
                )
            )

        self.mlp = nn.Sequential(
            nn.Linear(n_hidden, classifier_hidden),
            nn.ReLU(),
            nn.Linear(classifier_hidden, n_classes),
        )
        self.reset_parameters()

    def reset_parameters(self):
        for gcn_weight in self.gcn_weights_list:
            init.xavier_uniform_(gcn_weight)

    def forward(self, g_list):
        feature_list = []
        for g in g_list:
            feature_list.append(g.ndata["feat"])
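        # For each layer, evolve the weights across snapshots with the GRU (the
        # weights are their own input; see MatGRUCell), then convolve with the
        # evolved W.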
        for i in range(self.num_layers):
            W = self.gcn_weights_list[i]
            for j, g in enumerate(g_list):
                # Note: I tried to set gcn.weight directly with the code below (similar to
                # pyG_temporal), but it doesn't work: the gradient function gets lost in
                # this situation. For more discussion, see:
                # https://github.com/benedekrozemberczki/pytorch_geometric_temporal/issues/80
                # ====================================================
                # W = self.gnn_convs[i].weight[None, :, :]
                # W, _ = self.recurrent_layers[i](W)
                # self.gnn_convs[i].weight = nn.Parameter(W.squeeze())
                # ====================================================

                # If the following line is removed, the model degenerates into a static GCN.
                W = self.recurrent_layers[i](W)
                feature_list[j] = self.gnn_convs[i](
                    g, feature_list[j], weight=W
                )
        return self.mlp(feature_list[-1])
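

if __name__ == "__main__":
    # A minimal smoke test, not part of the original pipeline: build a few random
    # DGL snapshots with 166-dim node features (matching the default Elliptic-style
    # config above) and run both models once. The sizes are arbitrary, but each
    # snapshot needs at least n_hidden (=76) nodes so that EvolveGCNH's TopK
    # pooling has enough nodes to select from.
    import dgl

    g_list = []
    for _ in range(3):
        # self-loops avoid zero-in-degree nodes, which GraphConv rejects by default
        g = dgl.add_self_loop(dgl.rand_graph(100, 500))
        g.ndata["feat"] = torch.randn(g.num_nodes(), 166)
        g_list.append(g)

    for model in (EvolveGCNH(), EvolveGCNO()):
        logits = model(g_list)
        print(type(model).__name__, logits.shape)  # torch.Size([100, 2])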