"script/cuda.sh" did not exist on "136ba78cd3ecded46e032d41356237528fc3d9be"
layers.py 12.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
import dgl
import dgl.function as fn
import scipy.sparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.nn import AvgPooling, GraphConv, MaxPooling
from dgl.ops import edge_softmax
from torch import Tensor
from torch.nn import Parameter

from functions import edge_sparsemax
from utils import get_batch_id, topk


class WeightedGraphConv(GraphConv):
    r"""
    Description
    -----------
    GraphConv with edge weights on homogeneous graphs.
    If edge weights are not given, this falls back to the vanilla GraphConv
    forward.

    Parameters
    ----------
    graph : DGLGraph
        The graph on which to perform this operation.
    n_feat : torch.Tensor
        The node features.
    e_feat : torch.Tensor, optional
        The edge features. Default: :obj:`None`
    """
    def forward(self, graph:DGLGraph, n_feat, e_feat=None):
        if e_feat is None:
            return super(WeightedGraphConv, self).forward(graph, n_feat)
        
        with graph.local_scope():
            if self.weight is not None:
                n_feat = torch.matmul(n_feat, self.weight)
            src_norm = torch.pow(graph.out_degrees().float().clamp(min=1), -0.5)
            src_norm = src_norm.view(-1, 1)
            dst_norm = torch.pow(graph.in_degrees().float().clamp(min=1), -0.5)
            dst_norm = dst_norm.view(-1, 1)
            n_feat = n_feat * src_norm
            graph.ndata["h"] = n_feat
            graph.edata["e"] = e_feat
            graph.update_all(fn.src_mul_edge("h", "e", "m"),
                             fn.sum("m", "h"))
            n_feat = graph.ndata.pop("h")
            n_feat = n_feat * dst_norm
            if self.bias is not None:
                n_feat = n_feat + self.bias
            if self._activation is not None:
                n_feat = self._activation(n_feat)
            return n_feat
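

# A minimal usage sketch for WeightedGraphConv, not part of the original
# module. It assumes a small homogeneous graph with one non-negative scalar
# weight per edge, and shows both the weighted path and the fallback to the
# plain GraphConv forward when no edge weights are passed.
def _example_weighted_graph_conv():
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    g = dgl.add_self_loop(g)                # avoid zero-in-degree nodes
    n_feat = torch.randn(g.num_nodes(), 8)
    e_feat = torch.rand(g.num_edges())      # one scalar weight per edge
    conv = WeightedGraphConv(8, 16)
    out_weighted = conv(g, n_feat, e_feat)  # weighted propagation
    out_plain = conv(g, n_feat)             # falls back to GraphConv.forward
    return out_weighted, out_plain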


class NodeInfoScoreLayer(nn.Module):
    r"""
    Description
    -----------
    Compute a score for each node, which is used to select nodes in top-k
    pooling. The score of a node is the Manhattan (L1) distance between its
    own features and the features aggregated from its one-hop neighbors
    (a one-step random walk over the graph).

    Arguments
    ---------
    sym_norm : bool, optional
        If true, use symmetric norm for adjacency.
        Default: :obj:`True`

    Parameters
    ----------
    graph : DGLGraph
        The graph on which to perform this operation.
    feat : torch.Tensor
        The node features.
    e_feat : torch.Tensor
        The edge features.
    
    Returns
    -------
    Tensor
        Score for each node.
    """
    def __init__(self, sym_norm:bool=True):
        super(NodeInfoScoreLayer, self).__init__()
        self.sym_norm = sym_norm

    def forward(self, graph:dgl.DGLGraph, feat:Tensor, e_feat:Tensor):
        with graph.local_scope():
            if self.sym_norm:
                src_norm = torch.pow(graph.out_degrees().float().clamp(min=1), -0.5)
                src_norm = src_norm.view(-1, 1).to(feat.device)
                dst_norm = torch.pow(graph.in_degrees().float().clamp(min=1), -0.5)
                dst_norm = dst_norm.view(-1, 1).to(feat.device)

                src_feat = feat * src_norm
                
                graph.ndata["h"] = src_feat
                graph.edata["e"] = e_feat
                graph = dgl.remove_self_loop(graph)
                graph.update_all(fn.src_mul_edge("h", "e", "m"), fn.sum("m", "h"))
                
                dst_feat = graph.ndata.pop("h") * dst_norm
                feat = feat - dst_feat
            else:
                dst_norm = 1. / graph.in_degrees().float().clamp(min=1)
                dst_norm = dst_norm.view(-1, 1)

                graph.ndata["h"] = feat
                graph.edata["e"] = e_feat
                graph = dgl.remove_self_loop(graph)
                graph.update_all(fn.src_mul_edge("h", "e", "m"), fn.sum("m", "h"))

                feat = feat - dst_norm * graph.ndata.pop("h")

            score = torch.sum(torch.abs(feat), dim=1)
            return score
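

# A minimal sketch, not part of the original module, of how the node
# information score behaves: with symmetric normalization the score of node i
# is || (X - D^(-1/2) A D^(-1/2) X)_i ||_1, the Manhattan distance between a
# node's features and the features aggregated from its one-hop neighbors.
# The tiny graph, features and unit edge weights below are made-up inputs.
def _example_node_info_score():
    g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))
    feat = torch.randn(g.num_nodes(), 4)
    e_feat = torch.ones(g.num_edges())
    score = NodeInfoScoreLayer(sym_norm=True)(g, feat, e_feat)
    return score  # shape (num_nodes,); higher scores mark more informative nodes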


class HGPSLPool(nn.Module):
    r"""

    Description
    -----------
    The HGP-SL pooling layer from 
    `Hierarchical Graph Pooling with Structure Learning <https://arxiv.org/pdf/1911.05954.pdf>`

    Parameters
    ----------
    in_feat : int
        The number of channels of the input node features.
    ratio : float, optional
        Pooling ratio. Default: 0.8
    sample : bool, optional
        Whether to learn edge weights only on a sampled multi-hop union graph
        instead of on the complete graph, which is faster for large graphs.
        Default: :obj:`True`
    sym_score_norm : bool, optional
        Whether to use symmetric normalization of the adjacency when computing
        node information scores. Default: :obj:`True`
    sparse : bool, optional
        Use edge sparsemax instead of edge softmax. Default: :obj:`True`
    sl : bool, optional
        Whether to use the structure learning module. Default: :obj:`True`
    lamb : float, optional
        The lambda parameter that weights the original adjacency when it is
        added to the learned structure, as described in the HGP-SL paper.
        Default: 1.0
    negative_slop : float, optional
        Negative slope for leaky_relu. Default: 0.2
    k_hop : int, optional
        Number of adjacency-expansion iterations used to build the multi-hop
        union graph when :attr:`sample` is True. Default: 3
    
    Returns
    -------
    DGLGraph
        The pooled graph.
    torch.Tensor
        Node features
    torch.Tensor
        Edge features
    torch.Tensor
        Permutation index
    """
    def __init__(self, in_feat:int, ratio=0.8, sample=True, 
                 sym_score_norm=True, sparse=True, sl=True,
                 lamb=1.0, negative_slop=0.2, k_hop=3):
        super(HGPSLPool, self).__init__()
        self.in_feat = in_feat
        self.ratio = ratio
        self.sample = sample
        self.sparse = sparse
        self.sl = sl
        self.lamb = lamb
        self.negative_slop = negative_slop
        self.k_hop = k_hop

        self.att = Parameter(torch.Tensor(1, self.in_feat * 2))
        self.calc_info_score = NodeInfoScoreLayer(sym_norm=sym_score_norm)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_normal_(self.att.data)

    def forward(self, graph:DGLGraph, feat:Tensor, e_feat=None):
        # top-k pool first
        if e_feat is None:
            e_feat = torch.ones((graph.number_of_edges(),), 
                                dtype=feat.dtype, device=feat.device)
        batch_num_nodes = graph.batch_num_nodes()
        x_score = self.calc_info_score(graph, feat, e_feat)
        perm, next_batch_num_nodes = topk(x_score, self.ratio, 
                                          get_batch_id(batch_num_nodes),
                                          batch_num_nodes)
        feat = feat[perm]
        pool_graph = None
        if not self.sample or not self.sl:
            # pool graph
            graph.edata["e"] = e_feat
            pool_graph = dgl.node_subgraph(graph, perm)
            e_feat = pool_graph.edata.pop("e")
            pool_graph.set_batch_num_nodes(next_batch_num_nodes)

        # no structure learning layer, directly return.
        if not self.sl:
            return pool_graph, feat, e_feat, perm

        # Structure Learning
        if self.sample:
            # A fast mode for large graphs.
            # In large graphs, learning the possible edge weights between each
            # pair of nodes is time consuming. To accelerate this process,
            # we sample the k-hop neighbors of each node and then learn the
            # edge weights between them.
            
            # first build multi-hop graph
            row, col = graph.all_edges()
            num_nodes = graph.num_nodes()

            scipy_adj = scipy.sparse.coo_matrix(
                (e_feat.detach().cpu(),
                 (row.detach().cpu(), col.detach().cpu())),
                shape=(num_nodes, num_nodes))
            for _ in range(self.k_hop):
                two_hop = scipy_adj ** 2
                two_hop = two_hop * (1e-5 / two_hop.max())
                scipy_adj = two_hop + scipy_adj
            row, col = scipy_adj.nonzero()
            row = torch.tensor(row, dtype=torch.long, device=graph.device)
            col = torch.tensor(col, dtype=torch.long, device=graph.device)
            e_feat = torch.tensor(scipy_adj.data, dtype=torch.float, device=feat.device)

            # perform pooling on multi-hop graph
            mask = perm.new_full((num_nodes, ), -1)
            i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
            mask[perm] = i
            row, col = mask[row], mask[col]
            mask = (row >= 0) & (col >= 0)
            row, col = row[mask], col[mask]
            e_feat = e_feat[mask]

            # add self loops for every node, keeping weights of existing ones
            mask = row != col
            num_nodes = perm.size(0) # num nodes after pool
            loop_index = torch.arange(0, num_nodes, dtype=row.dtype, device=row.device)
            inv_mask = ~mask
            loop_weight = torch.full((num_nodes, ), 0, dtype=e_feat.dtype, device=e_feat.device)
            remaining_e_feat = e_feat[inv_mask]
            if remaining_e_feat.numel() > 0:
                loop_weight[row[inv_mask]] = remaining_e_feat
            e_feat = torch.cat([e_feat[mask], loop_weight], dim=0)
            row, col = row[mask], col[mask]
            row = torch.cat([row, loop_index], dim=0)
            col = torch.cat([col, loop_index], dim=0)

            # attention scores
            weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slop) + e_feat * self.lamb
            
            # sl and normalization
            sl_graph = dgl.graph((row, col))
            if self.sparse:
                weights = edge_sparsemax(sl_graph, weights)
            else:
                weights = edge_softmax(sl_graph, weights)
            
            # get final graph
            mask = torch.abs(weights) > 0
            row, col, weights = row[mask], col[mask], weights[mask]
            pool_graph = dgl.graph((row, col))
            pool_graph.set_batch_num_nodes(next_batch_num_nodes)
            e_feat = weights

        else:
            # Learn the possible edge weights between each pair of nodes
            # in the pooled subgraph; this is relatively slower.

            # construct a complete graph for every graph in the batch,
            # building it densely first and then converting to sparse.
            # There may be a more efficient way to do this.
            batch_num_nodes = next_batch_num_nodes
            block_begin_idx = torch.cat([batch_num_nodes.new_zeros(1), 
                batch_num_nodes.cumsum(dim=0)[:-1]], dim=0)
            block_end_idx = batch_num_nodes.cumsum(dim=0)
            dense_adj = torch.zeros((pool_graph.num_nodes(),
                                    pool_graph.num_nodes()),
                                    dtype=torch.float, 
                                    device=feat.device)
            for idx_b, idx_e in zip(block_begin_idx, block_end_idx):
                dense_adj[idx_b:idx_e, idx_b:idx_e] = 1.
            row, col = torch.nonzero(dense_adj).t().contiguous()

            # compute weights for node-pairs
            weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slop)
            dense_adj[row, col] = weights

            # add pooled graph structure to weight matrix
            pool_row, pool_col = pool_graph.all_edges()
            dense_adj[pool_row, pool_col] += self.lamb * e_feat
            weights = dense_adj[row, col]
            del dense_adj
            torch.cuda.empty_cache()

            # edge softmax/sparsemax
            complete_graph = dgl.graph((row, col))
            if self.sparse:
                weights = edge_sparsemax(complete_graph, weights)
            else:
                weights = edge_softmax(complete_graph, weights)

            # get new e_feat and graph structure, clean up.
            mask = torch.abs(weights) > 1e-9
            row, col, weights = row[mask], col[mask], weights[mask]
            e_feat = weights
            pool_graph = dgl.graph((row, col))
            pool_graph.set_batch_num_nodes(next_batch_num_nodes)

        return pool_graph, feat, e_feat, perm
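

# A minimal usage sketch for HGPSLPool, not part of the original module.
# It pools a small made-up batched graph with ratio 0.5 using the
# sampling-based structure learning branch (sample=True); the graphs and
# feature sizes are illustrative only.
def _example_hgpsl_pool():
    g1 = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0])))
    g2 = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
    bg = dgl.batch([g1, g2])
    feat = torch.randn(bg.num_nodes(), 16)
    pool = HGPSLPool(16, ratio=0.5, sample=True)
    pooled_g, pooled_feat, pooled_e, perm = pool(bg, feat)
    # pooled_g keeps roughly half of the nodes of each graph in the batch;
    # pooled_e holds the learned edge weights of the new structure.
    return pooled_g, pooled_feat, pooled_e, perm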


class ConvPoolReadout(torch.nn.Module):
    """A helper class. (GraphConv -> Pooling -> Readout)"""
    def __init__(self, in_feat:int, out_feat:int, pool_ratio=0.8,
                 sample:bool=False, sparse:bool=True, sl:bool=True,
                 lamb:float=1., pool:bool=True):
        super(ConvPoolReadout, self).__init__()
        self.use_pool = pool
        self.conv = WeightedGraphConv(in_feat, out_feat)
        if pool:
            self.pool = HGPSLPool(out_feat, ratio=pool_ratio, sparse=sparse,
                                  sample=sample, sl=sl, lamb=lamb)
        else:
            self.pool = None
        self.avgpool = AvgPooling()
        self.maxpool = MaxPooling()

    def forward(self, graph, feature, e_feat=None):
        out = F.relu(self.conv(graph, feature, e_feat))
        if self.use_pool:
            graph, out, e_feat, _ = self.pool(graph, out, e_feat)
        readout = torch.cat([self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1)
        return graph, out, e_feat, readout
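

# A minimal end-to-end sketch, not part of the original module: one
# ConvPoolReadout block (GraphConv -> HGP-SL pooling -> mean/max readout)
# applied to a made-up batched graph. Sizes are illustrative only.
if __name__ == "__main__":
    g1 = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0])))
    g2 = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
    bg = dgl.batch([g1, g2])
    feat = torch.randn(bg.num_nodes(), 8)
    block = ConvPoolReadout(8, 32, pool_ratio=0.5)
    pooled_g, node_feat, edge_feat, readout = block(bg, feat)
    # readout concatenates average and max pooling: shape (batch_size, 2 * out_feat)
    print(readout.shape)  # expected: torch.Size([2, 64])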