# models.py
1
import torch
2
import torch.nn as nn
3
4

import dgl.nn.pytorch as dglnn
5
from dgl import function as fn
6
from dgl.ops import edge_softmax
7
from dgl.utils import expand_as_pair
8

9

10
11
class ElementWiseLinear(nn.Module):
    """Per-feature affine transform ``x * weight + bias``.

    Unlike ``nn.Linear`` there is no mixing between features: ``weight`` and
    ``bias`` have shape ``size`` and are broadcast against ``x`` using the
    usual PyTorch broadcasting rules (i.e. over its trailing dimensions).

    Args:
        size: Shape of the learnable scale/shift (an int or a shape tuple).
        weight: If True, learn a scale, initialized to ones.
        bias: If True, learn a shift, initialized to zeros.
        inplace: If True, ``forward`` mutates its input and returns it.
    """

    def __init__(self, size, weight=True, bias=True, inplace=False):
        super().__init__()
        # torch.empty treats ``size`` as a shape (torch.Tensor would treat a
        # tuple as data); the real starting values come from
        # reset_parameters() below.
        self.weight = nn.Parameter(torch.empty(size)) if weight else None
        self.bias = nn.Parameter(torch.empty(size)) if bias else None
        self.inplace = inplace

        self.reset_parameters()

    def reset_parameters(self):
        """Reset the scale to ones and the shift to zeros (identity map)."""
        if self.weight is not None:
            nn.init.ones_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        """Return ``x`` scaled by ``weight`` and shifted by ``bias``.

        With ``inplace=True`` the input tensor itself is modified and
        returned, so callers must not rely on its previous contents.
        """
        if self.inplace:
            if self.weight is not None:
                x.mul_(self.weight)
            if self.bias is not None:
                x.add_(self.bias)
            return x

        if self.weight is not None:
            x = x * self.weight
        if self.bias is not None:
            x = x + self.bias
        return x


class GCN(nn.Module):
    """Stack of ``dglnn.GraphConv`` layers for node classification.

    Every layer except the last is followed by BatchNorm -> activation ->
    dropout. Only the final GraphConv has a bias term. When ``use_linear`` is
    set, a bias-free ``nn.Linear`` projection of each layer's input is added
    to that layer's graph-convolution output.

    Args:
        in_feats: Input feature size.
        n_hidden: Feature size of the intermediate layers.
        n_classes: Output feature size of the last layer.
        n_layers: Number of GraphConv layers.
        activation: Callable applied after each hidden-layer BatchNorm.
        dropout: Dropout probability between hidden layers; the input
            dropout uses ``min(0.1, dropout)``.
        use_linear: Whether to add the per-layer linear projection.
    """

    def __init__(
        self,
        in_feats,
        n_hidden,
        n_classes,
        n_layers,
        activation,
        dropout,
        use_linear,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.use_linear = use_linear

        self.convs = nn.ModuleList()
        if use_linear:
            self.linear = nn.ModuleList()
        self.norms = nn.ModuleList()

        for i in range(n_layers):
            is_last = i == n_layers - 1
            in_hidden = n_hidden if i > 0 else in_feats
            out_hidden = n_classes if is_last else n_hidden

            # Hidden layers skip the conv bias (BatchNorm follows them).
            self.convs.append(
                dglnn.GraphConv(in_hidden, out_hidden, "both", bias=is_last)
            )
            if use_linear:
                self.linear.append(nn.Linear(in_hidden, out_hidden, bias=False))
            if not is_last:
                self.norms.append(nn.BatchNorm1d(out_hidden))

        self.input_drop = nn.Dropout(min(0.1, dropout))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, graph, feat):
        """Run all layers on ``feat`` over ``graph`` and return the logits."""
        h = self.input_drop(feat)

        for i in range(self.n_layers):
            conv = self.convs[i](graph, h)
            if self.use_linear:
                h = conv + self.linear[i](h)
            else:
                h = conv

            if i < self.n_layers - 1:
                h = self.norms[i](h)
                h = self.activation(h)
                h = self.dropout(h)

        return h


class GATConv(nn.Module):
    """Graph attention convolution with edge dropout and symmetric norm options.

    Extra options compared with a plain GAT layer:

    * ``edge_drop``: while training, a random ``edge_drop`` fraction of edges
      is left out of the attention softmax (their attention weight is 0).
    * ``use_attn_dst``: if False, attention logits use only the source-side
      term ``a_l Wh_i`` (no destination attention vector is learned).
    * ``use_symmetric_norm``: projected source features are scaled by
      ``out_degree ** -0.5`` before aggregation and the aggregated result by
      ``in_degree ** 0.5`` (degrees clamped to at least 1).

    Args:
        in_feats: Input feature size, or a ``(src, dst)`` pair of sizes, in
            which case separate source/destination projections are learned.
        out_feats: Output feature size per head.
        num_heads: Number of attention heads.
        feat_drop: Dropout probability on input features.
        attn_drop: Dropout probability on attention weights.
        edge_drop: Fraction of edges excluded from the softmax in training.
        negative_slope: LeakyReLU slope applied to attention logits.
        use_attn_dst: Whether to learn a destination attention vector.
        residual: If True, add a bias-free linear projection of the
            destination input features to the output.
        activation: Optional callable applied to the final output.
        allow_zero_in_degree: If False, ``forward`` raises ``ValueError`` on
            graphs that contain zero-in-degree nodes.
        use_symmetric_norm: Enable the degree normalization described above.

    ``forward`` returns a tensor of shape
    ``(num_dst_nodes, num_heads, out_feats)``.
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        num_heads=1,
        feat_drop=0.0,
        attn_drop=0.0,
        edge_drop=0.0,
        negative_slope=0.2,
        use_attn_dst=True,
        residual=False,
        activation=None,
        allow_zero_in_degree=False,
        use_symmetric_norm=False,
    ):
        super().__init__()
        self._num_heads = num_heads
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._allow_zero_in_degree = allow_zero_in_degree
        self._use_symmetric_norm = use_symmetric_norm
        if isinstance(in_feats, tuple):
            self.fc_src = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=False
            )
            self.fc_dst = nn.Linear(
                self._in_dst_feats, out_feats * num_heads, bias=False
            )
        else:
            self.fc = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=False
            )
        self.attn_l = nn.Parameter(
            torch.FloatTensor(size=(1, num_heads, out_feats))
        )
        if use_attn_dst:
            self.attn_r = nn.Parameter(
                torch.FloatTensor(size=(1, num_heads, out_feats))
            )
        else:
            self.register_buffer("attn_r", None)
        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.edge_drop = edge_drop
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        if residual:
            self.res_fc = nn.Linear(
                self._in_dst_feats, num_heads * out_feats, bias=False
            )
        else:
            self.register_buffer("res_fc", None)
        self.reset_parameters()
        self._activation = activation

    def reset_parameters(self):
        """Xavier-normal init (ReLU gain) for projections and attention."""
        gain = nn.init.calculate_gain("relu")
        if hasattr(self, "fc"):
            nn.init.xavier_normal_(self.fc.weight, gain=gain)
        else:
            nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
            nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_l, gain=gain)
        if isinstance(self.attn_r, nn.Parameter):
            nn.init.xavier_normal_(self.attn_r, gain=gain)
        if isinstance(self.res_fc, nn.Linear):
            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)

    def set_allow_zero_in_degree(self, set_value):
        """Toggle whether ``forward`` accepts zero-in-degree nodes."""
        self._allow_zero_in_degree = set_value

    def forward(self, graph, feat):
        """Compute attention-weighted aggregation of ``feat`` over ``graph``.

        Args:
            graph: DGL graph or block.
            feat: Node features, or a ``(src_feat, dst_feat)`` pair.

        Raises:
            ValueError: If the graph has zero-in-degree nodes and
                ``allow_zero_in_degree`` is False.
        """
        with graph.local_scope():
            # Explicit raise: an ``assert`` would be stripped under -O.
            if not self._allow_zero_in_degree and (
                graph.in_degrees() == 0
            ).any():
                raise ValueError(
                    "There are 0-in-degree nodes in the graph, which receive "
                    "no messages. Add self-loops or construct GATConv with "
                    "allow_zero_in_degree=True."
                )

            if isinstance(feat, tuple):
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                # A layer built from a single in_feats only owns ``fc``; use
                # it for both sides via locals so that no new submodules get
                # registered (which would change the state_dict keys).
                if hasattr(self, "fc_src"):
                    fc_src, fc_dst = self.fc_src, self.fc_dst
                else:
                    fc_src = fc_dst = self.fc
                feat_src = fc_src(h_src).view(
                    -1, self._num_heads, self._out_feats
                )
                feat_dst = fc_dst(h_dst).view(
                    -1, self._num_heads, self._out_feats
                )
            else:
                h_src = self.feat_drop(feat)
                feat_src = self.fc(h_src).view(
                    -1, self._num_heads, self._out_feats
                )
                if graph.is_block:
                    # Destination nodes are the leading rows of the source
                    # features in a block.
                    num_dst = graph.number_of_dst_nodes()
                    h_dst = h_src[:num_dst]
                    feat_dst = feat_src[:num_dst]
                else:
                    h_dst = h_src
                    feat_dst = feat_src

            if self._use_symmetric_norm:
                degs = graph.out_degrees().float().clamp(min=1)
                norm = torch.pow(degs, -0.5)
                shp = norm.shape + (1,) * (feat_src.dim() - 1)
                feat_src = feat_src * torch.reshape(norm, shp)

            # NOTE: The GAT paper uses "first concatenation then linear
            # projection" to compute attention scores, while ours is "first
            # projection then addition"; the two are mathematically
            # equivalent. We decompose the weight vector a from the paper into
            # [a_l || a_r], then
            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
            # This is more efficient because [Wh_i || Wh_j] never has to be
            # stored on edges, and the addition can use DGL's built-in
            # u_add_v, which further saves time and memory.
            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            graph.srcdata.update({"ft": feat_src, "el": el})
            # Edge attention logits: el + er, or el alone without attn_r.
            if self.attn_r is not None:
                er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
                graph.dstdata.update({"er": er})
                graph.apply_edges(fn.u_add_v("el", "er", "e"))
            else:
                graph.apply_edges(fn.copy_u("el", "e"))
            e = self.leaky_relu(graph.edata.pop("e"))

            if self.training and self.edge_drop > 0:
                # Keep a random subset of edges; dropped edges get weight 0
                # and do not take part in the softmax normalization.
                num_edges = graph.number_of_edges()
                perm = torch.randperm(num_edges, device=e.device)
                bound = int(num_edges * self.edge_drop)
                eids = perm[bound:]
                graph.edata["a"] = torch.zeros_like(e)
                graph.edata["a"][eids] = self.attn_drop(
                    edge_softmax(graph, e[eids], eids=eids)
                )
            else:
                graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))

            # Message passing: attention-weighted sum of source features.
            graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
            rst = graph.dstdata["ft"]

            if self._use_symmetric_norm:
                degs = graph.in_degrees().float().clamp(min=1)
                norm = torch.pow(degs, 0.5)
                shp = norm.shape + (1,) * (feat_dst.dim() - 1)
                rst = rst * torch.reshape(norm, shp)

            if self.res_fc is not None:
                resval = self.res_fc(h_dst).view(
                    h_dst.shape[0], -1, self._out_feats
                )
                rst = rst + resval

            if self._activation is not None:
                rst = self._activation(rst)

            return rst


class GAT(nn.Module):
    """Stack of ``GATConv`` layers (with residuals) for node classification.

    Hidden layers use ``n_heads`` heads. Their outputs are concatenated by
    flattening the head dimension, then passed through BatchNorm ->
    activation -> dropout. The last layer has a single head producing
    ``n_classes`` features. The head dimension is averaged out and a learned
    per-class bias is added.

    Args:
        in_feats: Input feature size.
        n_classes: Number of output classes.
        n_hidden: Per-head hidden feature size.
        n_layers: Number of GATConv layers.
        n_heads: Number of heads in each hidden layer.
        activation: Callable accepting an ``inplace`` keyword (it is called
            as ``activation(h, inplace=True)``).
        dropout: Dropout probability after each hidden layer.
        input_drop: Dropout probability on the input features.
        attn_drop: Attention dropout passed to every GATConv.
        edge_drop: Edge dropout fraction passed to every GATConv.
        use_attn_dst: Passed to every GATConv.
        use_symmetric_norm: Passed to every GATConv.
    """

    def __init__(
        self,
        in_feats,
        n_classes,
        n_hidden,
        n_layers,
        n_heads,
        activation,
        dropout=0.0,
        input_drop=0.0,
        attn_drop=0.0,
        edge_drop=0.0,
        use_attn_dst=True,
        use_symmetric_norm=False,
    ):
        super().__init__()
        self.in_feats = in_feats
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.num_heads = n_heads

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        for i in range(n_layers):
            is_last = i == n_layers - 1
            # Hidden layers concatenate their heads, so every layer after the
            # first receives n_heads * n_hidden input features.
            in_hidden = n_heads * n_hidden if i > 0 else in_feats
            out_hidden = n_classes if is_last else n_hidden
            num_heads = 1 if is_last else n_heads

            self.convs.append(
                GATConv(
                    in_hidden,
                    out_hidden,
                    num_heads=num_heads,
                    attn_drop=attn_drop,
                    edge_drop=edge_drop,
                    use_attn_dst=use_attn_dst,
                    use_symmetric_norm=use_symmetric_norm,
                    residual=True,
                )
            )
            if not is_last:
                self.norms.append(nn.BatchNorm1d(n_heads * out_hidden))

        self.bias_last = ElementWiseLinear(
            n_classes, weight=False, bias=True, inplace=True
        )
        self.input_drop = nn.Dropout(input_drop)
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, graph, feat):
        """Run all layers on ``feat`` over ``graph`` and return the logits."""
        h = self.input_drop(feat)

        for i in range(self.n_layers):
            h = self.convs[i](graph, h)
            if i < self.n_layers - 1:
                # Concatenate heads: (N, heads, hidden) -> (N, heads*hidden).
                h = h.flatten(1)
                h = self.norms[i](h)
                h = self.activation(h, inplace=True)
                h = self.dropout(h)

        # Average over the head dimension, then add the per-class bias.
        h = h.mean(1)
        h = self.bias_last(h)
        return h