import dgl
import dgl.function as fn
import scipy.sparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.nn import AvgPooling, GraphConv, MaxPooling
from dgl.ops import edge_softmax

from functions import edge_sparsemax
from torch import Tensor
from torch.nn import Parameter
from utils import get_batch_id, topk


class WeightedGraphConv(GraphConv):
    r"""
    Description
    -----------
    GraphConv that honors per-edge weights on homogeneous graphs.
    When no edge weights are supplied, the call is delegated to the
    plain :class:`GraphConv` forward.

    Parameters
    ----------
    graph : DGLGraph
        The graph to perform this operation.
    n_feat : torch.Tensor
        The node features
    e_feat : torch.Tensor, optional
        The edge features. Default: :obj:`None`
    """

    def forward(self, graph: DGLGraph, n_feat, e_feat=None):
        # Without edge weights this layer reduces to an ordinary GraphConv.
        if e_feat is None:
            return super().forward(graph, n_feat)

        with graph.local_scope():
            # Linear projection, if the parent layer owns a weight matrix.
            if self.weight is not None:
                n_feat = torch.matmul(n_feat, self.weight)

            # Symmetric degree normalization: D_out^{-1/2} ... D_in^{-1/2}.
            # clamp(min=1) guards against division by zero on isolated nodes.
            out_norm = torch.pow(
                graph.out_degrees().float().clamp(min=1), -0.5
            ).view(-1, 1)
            in_norm = torch.pow(
                graph.in_degrees().float().clamp(min=1), -0.5
            ).view(-1, 1)

            graph.ndata["h"] = n_feat * out_norm
            graph.edata["e"] = e_feat
            # Weighted message passing: m_uv = h_u * e_uv, summed per node.
            graph.update_all(fn.u_mul_e("h", "e", "m"), fn.sum("m", "h"))
            rst = graph.ndata.pop("h") * in_norm

            if self.bias is not None:
                rst = rst + self.bias
            if self._activation is not None:
                rst = self._activation(rst)
            return rst


class NodeInfoScoreLayer(nn.Module):
    r"""
    Description
    -----------
    Compute a score for each node for sort-pooling. The score of each node
    is computed via the absolute difference of its first-order random walk
    result and its features.

    Arguments
    ---------
    sym_norm : bool, optional
        If true, use symmetric norm for adjacency.
        Default: :obj:`True`

    Parameters
    ----------
    graph : DGLGraph
        The graph to perform this operation.
    feat : torch.Tensor
        The node features
    e_feat : torch.Tensor, optional
        The edge features. Default: :obj:`None`

    Returns
    -------
    Tensor
        Score for each node.
    """

    def __init__(self, sym_norm: bool = True):
        super(NodeInfoScoreLayer, self).__init__()
        self.sym_norm = sym_norm

    def forward(self, graph: dgl.DGLGraph, feat: Tensor, e_feat: Tensor):
        with graph.local_scope():
            if self.sym_norm:
                # Symmetric normalization: propagate D^{-1/2} A D^{-1/2} h.
                # Degrees are taken before self-loop removal, matching the
                # graph the caller passed in.
                out_deg = graph.out_degrees().float().clamp(min=1)
                in_deg = graph.in_degrees().float().clamp(min=1)
                src_norm = torch.pow(out_deg, -0.5).view(-1, 1).to(feat.device)
                dst_norm = torch.pow(in_deg, -0.5).view(-1, 1).to(feat.device)

                graph.ndata["h"] = feat * src_norm
                graph.edata["e"] = e_feat
                graph = dgl.remove_self_loop(graph)
                graph.update_all(fn.u_mul_e("h", "e", "m"), fn.sum("m", "h"))

                # Residual between raw features and the propagated result.
                feat = feat - graph.ndata.pop("h") * dst_norm
            else:
                # Random-walk normalization: propagate D^{-1} A h.
                dst_norm = (
                    1.0 / graph.in_degrees().float().clamp(min=1)
                ).view(-1, 1)

                graph.ndata["h"] = feat
                graph.edata["e"] = e_feat
                graph = dgl.remove_self_loop(graph)
                graph.update_all(fn.u_mul_e("h", "e", "m"), fn.sum("m", "h"))

                feat = feat - dst_norm * graph.ndata.pop("h")

            # The node score is the L1 norm of the residual per node.
            return torch.sum(torch.abs(feat), dim=1)


class HGPSLPool(nn.Module):
    r"""

    Description
    -----------
    The HGP-SL pooling layer from
    `Hierarchical Graph Pooling with Structure Learning <https://arxiv.org/pdf/1911.05954.pdf>`

    Parameters
    ----------
    in_feat : int
        The number of input node feature's channels
    ratio : float, optional
        Pooling ratio. Default: 0.8
    sample : bool, optional
        Whether use k-hop union graph to increase efficiency.
        Currently we only support full graph. Default: :obj:`False`
    sym_score_norm : bool, optional
        Use symmetric norm for adjacency or not. Default: :obj:`True`
    sparse : bool, optional
        Use edge sparsemax instead of edge softmax. Default: :obj:`True`
    sl : bool, optional
        Use structure learining module or not. Default: :obj:`True`
    lamb : float, optional
        The lambda parameter as weight of raw adjacency as described in the
        HGP-SL paper. Default: 1.0
    negative_slop : float, optional
        Negative slop for leaky_relu. Default: 0.2
    k_hop : int, optional
        Number of hops used to build the union graph when
        ``sample=True``. Default: 3

    Returns
    -------
    DGLGraph
        The pooled graph.
    torch.Tensor
        Node features
    torch.Tensor
        Edge features
    torch.Tensor
        Permutation index
    """

    def __init__(
        self,
        in_feat: int,
        ratio=0.8,
        sample=True,
        sym_score_norm=True,
        sparse=True,
        sl=True,
        lamb=1.0,
        negative_slop=0.2,
        k_hop=3,
    ):
        super(HGPSLPool, self).__init__()
        self.in_feat = in_feat
        self.ratio = ratio
        self.sample = sample
        self.sparse = sparse
        self.sl = sl
        self.lamb = lamb
        self.negative_slop = negative_slop
        self.k_hop = k_hop

        # Attention vector applied to concatenated (src, dst) features,
        # hence width 2 * in_feat.
        self.att = Parameter(torch.Tensor(1, self.in_feat * 2))
        self.calc_info_score = NodeInfoScoreLayer(sym_norm=sym_score_norm)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the attention parameter with Xavier normal."""
        nn.init.xavier_normal_(self.att.data)

    def forward(self, graph: DGLGraph, feat: Tensor, e_feat=None):
        # top-k pool first
        if e_feat is None:
            # Default to unit edge weights when none are supplied.
            e_feat = torch.ones(
                (graph.number_of_edges(),), dtype=feat.dtype, device=feat.device
            )
        batch_num_nodes = graph.batch_num_nodes()
        # Score every node, then keep the top ratio fraction per graph.
        x_score = self.calc_info_score(graph, feat, e_feat)
        perm, next_batch_num_nodes = topk(
            x_score, self.ratio, get_batch_id(batch_num_nodes), batch_num_nodes
        )
        feat = feat[perm]
        pool_graph = None
        if not self.sample or not self.sl:
            # pool graph: induced subgraph on the kept nodes, carrying the
            # edge weights along via edata.
            graph.edata["e"] = e_feat
            pool_graph = dgl.node_subgraph(graph, perm)
            e_feat = pool_graph.edata.pop("e")
            pool_graph.set_batch_num_nodes(next_batch_num_nodes)

        # no structure learning layer, directly return.
        if not self.sl:
            return pool_graph, feat, e_feat, perm

        # Structure Learning
        if self.sample:
            # A fast mode for large graphs.
            # In large graphs, learning the possible edge weights between each
            # pair of nodes is time consuming. To accelerate this process,
            # we sample it's K-Hop neighbors for each node and then learn the
            # edge weights between them.

            # first build multi-hop graph
            row, col = graph.all_edges()
            num_nodes = graph.num_nodes()

            scipy_adj = scipy.sparse.coo_matrix(
                (
                    e_feat.detach().cpu(),
                    (row.detach().cpu(), col.detach().cpu()),
                ),
                shape=(num_nodes, num_nodes),
            )
            for _ in range(self.k_hop):
                # Squaring the adjacency reaches 2-hop neighbors; the tiny
                # scale factor keeps the added entries from dominating the
                # original weights while still marking connectivity.
                two_hop = scipy_adj**2
                two_hop = two_hop * (1e-5 / two_hop.max())
                scipy_adj = two_hop + scipy_adj
            row, col = scipy_adj.nonzero()
            row = torch.tensor(row, dtype=torch.long, device=graph.device)
            col = torch.tensor(col, dtype=torch.long, device=graph.device)
            e_feat = torch.tensor(
                scipy_adj.data, dtype=torch.float, device=feat.device
            )

            # perform pooling on multi-hop graph
            # mask maps old node ids -> new ids for kept nodes, -1 otherwise.
            mask = perm.new_full((num_nodes,), -1)
            i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
            mask[perm] = i
            row, col = mask[row], mask[col]
            # Keep only edges whose both endpoints survived the pooling.
            mask = (row >= 0) & (col >= 0)
            row, col = row[mask], col[mask]
            e_feat = e_feat[mask]

            # add remaining self loops
            mask = row != col
            num_nodes = perm.size(0)  # num nodes after pool
            loop_index = torch.arange(
                0, num_nodes, dtype=row.dtype, device=row.device
            )
            inv_mask = ~mask
            loop_weight = torch.full(
                (num_nodes,), 0, dtype=e_feat.dtype, device=e_feat.device
            )
            # Preserve the weight of any self-loop that already existed.
            remaining_e_feat = e_feat[inv_mask]
            if remaining_e_feat.numel() > 0:
                loop_weight[row[inv_mask]] = remaining_e_feat
            e_feat = torch.cat([e_feat[mask], loop_weight], dim=0)
            row, col = row[mask], col[mask]
            row = torch.cat([row, loop_index], dim=0)
            col = torch.cat([col, loop_index], dim=0)

            # attention scores: att . [h_src || h_dst], then leaky-relu,
            # plus the raw adjacency weighted by lambda (HGP-SL Eq. for
            # structure learning).
            weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(
                dim=-1
            )
            weights = (
                F.leaky_relu(weights, self.negative_slop) + e_feat * self.lamb
            )

            # sl and normalization
            # NOTE(review): dgl.graph((row, col)) infers num_nodes from the
            # max index; if the highest-id pooled nodes end up isolated the
            # graph could have fewer nodes than expected — consider passing
            # num_nodes explicitly. Confirm against dgl.graph semantics.
            sl_graph = dgl.graph((row, col))
            if self.sparse:
                weights = edge_sparsemax(sl_graph, weights)
            else:
                weights = edge_softmax(sl_graph, weights)

            # get final graph: drop edges whose normalized weight is zero.
            # NOTE(review): this branch uses > 0 while the dense branch uses
            # > 1e-9 — presumably intentional (sparsemax yields exact zeros),
            # but worth confirming.
            mask = torch.abs(weights) > 0
            row, col, weights = row[mask], col[mask], weights[mask]
            pool_graph = dgl.graph((row, col))
            pool_graph.set_batch_num_nodes(next_batch_num_nodes)
            e_feat = weights

        else:
            # Learning the possible edge weights between each pair of
            # nodes in the pooled subgraph, relative slower.

            # construct complete graphs for all graph in the batch
            # use dense to build, then transform to sparse.
            # maybe there's more efficient way?
            batch_num_nodes = next_batch_num_nodes
            # Start/end node offsets of each graph's block in the batched
            # (block-diagonal) adjacency.
            block_begin_idx = torch.cat(
                [
                    batch_num_nodes.new_zeros(1),
                    batch_num_nodes.cumsum(dim=0)[:-1],
                ],
                dim=0,
            )
            block_end_idx = batch_num_nodes.cumsum(dim=0)
            dense_adj = torch.zeros(
                (pool_graph.num_nodes(), pool_graph.num_nodes()),
                dtype=torch.float,
                device=feat.device,
            )
            for idx_b, idx_e in zip(block_begin_idx, block_end_idx):
                # Mark the complete graph within each batch element.
                dense_adj[idx_b:idx_e, idx_b:idx_e] = 1.0
            row, col = torch.nonzero(dense_adj).t().contiguous()

            # compute weights for node-pairs
            weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(
                dim=-1
            )
            weights = F.leaky_relu(weights, self.negative_slop)
            dense_adj[row, col] = weights

            # add pooled graph structure to weight matrix
            pool_row, pool_col = pool_graph.all_edges()
            dense_adj[pool_row, pool_col] += self.lamb * e_feat
            weights = dense_adj[row, col]
            # Free the dense matrix early; it is O(N^2) memory.
            del dense_adj
            torch.cuda.empty_cache()

            # edge softmax/sparsemax
            complete_graph = dgl.graph((row, col))
            if self.sparse:
                weights = edge_sparsemax(complete_graph, weights)
            else:
                weights = edge_softmax(complete_graph, weights)

            # get new e_feat and graph structure, clean up.
            mask = torch.abs(weights) > 1e-9
            row, col, weights = row[mask], col[mask], weights[mask]
            e_feat = weights
            # NOTE(review): as in the sample branch, dgl.graph((row, col))
            # infers the node count from the max index — verify isolated
            # trailing nodes cannot occur here.
            pool_graph = dgl.graph((row, col))
            pool_graph.set_batch_num_nodes(next_batch_num_nodes)

        return pool_graph, feat, e_feat, perm


class ConvPoolReadout(torch.nn.Module):
    """A helper class. (GraphConv -> Pooling -> Readout)"""

    def __init__(
        self,
        in_feat: int,
        out_feat: int,
        pool_ratio=0.8,
        sample: bool = False,
        sparse: bool = True,
        sl: bool = True,
        lamb: float = 1.0,
        pool: bool = True,
    ):
        super(ConvPoolReadout, self).__init__()
        self.use_pool = pool
        self.conv = WeightedGraphConv(in_feat, out_feat)
        # Build the pooling layer only when requested; otherwise keep None
        # so forward() can skip it.
        self.pool = (
            HGPSLPool(
                out_feat,
                ratio=pool_ratio,
                sparse=sparse,
                sample=sample,
                sl=sl,
                lamb=lamb,
            )
            if pool
            else None
        )
        self.avgpool = AvgPooling()
        self.maxpool = MaxPooling()

    def forward(self, graph, feature, e_feat=None):
        # Edge-weighted convolution followed by ReLU.
        hidden = F.relu(self.conv(graph, feature, e_feat))
        if self.use_pool:
            graph, hidden, e_feat, _ = self.pool(graph, hidden, e_feat)
        # Per-graph readout: mean-pool and max-pool, concatenated.
        readout = torch.cat(
            [self.avgpool(graph, hidden), self.maxpool(graph, hidden)], dim=-1
        )
        return graph, hidden, e_feat, readout