from typing import Optional

import dgl
import torch
import torch.nn
from dgl import DGLGraph
from dgl.nn import GraphConv
from torch import Tensor


class GraphConvWithDropout(GraphConv):
    """
    A GraphConv followed by a Dropout.
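
    Dropout is applied to the input features before the convolution.

    Examples
    --------
    A minimal usage sketch (the graph and sizes below are illustrative):

    >>> import dgl
    >>> import torch
    >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    >>> conv = GraphConvWithDropout(4, 8, dropout=0.3)
    >>> conv(g, torch.randn(3, 4)).shape
    torch.Size([3, 8])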
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        dropout=0.3,
        norm="both",
        weight=True,
        bias=True,
        activation=None,
        allow_zero_in_degree=False,
    ):
        super(GraphConvWithDropout, self).__init__(
            in_feats,
            out_feats,
            norm,
            weight,
            bias,
            activation,
            allow_zero_in_degree,
        )
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, graph, feat, weight=None):
        # Apply dropout to the input features, then run the regular GraphConv.
        feat = self.dropout(feat)
        return super(GraphConvWithDropout, self).forward(graph, feat, weight)


class Discriminator(torch.nn.Module):
    """
    Description
    -----------
    A discriminator that lets the network discriminate
    between positive (the neighborhood of the center node) and
    negative (any neighborhood in the graph) samples.

    Parameters
    ----------
    feat_dim : int
        The number of channels of node features.
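
    Examples
    --------
    A minimal scoring sketch (sizes are illustrative):

    >>> import torch
    >>> disc = Discriminator(16)
    >>> h = torch.randn(5, 16)
    >>> logits, score_pos = disc(h, torch.randn(5, 16), torch.randn(5, 16))
    >>> logits.shape, score_pos.shape
    (torch.Size([10]), torch.Size([5]))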
    """

    def __init__(self, feat_dim: int):
        super(Discriminator, self).__init__()
        self.affine = torch.nn.Bilinear(feat_dim, feat_dim, 1)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.affine.weight)
        torch.nn.init.zeros_(self.affine.bias)

    def forward(
        self,
        h_x: Tensor,
        h_pos: Tensor,
        h_neg: Tensor,
        bias_pos: Optional[Tensor] = None,
        bias_neg: Optional[Tensor] = None,
    ):
        """
        Parameters
        ----------
        h_x : torch.Tensor
            Node features, shape: :obj:`(num_nodes, feat_dim)`
        h_pos : torch.Tensor
            The node features of positive samples
            It has the same shape as :obj:`h_x`
        h_neg : torch.Tensor
            The node features of negative samples
            It has the same shape as :obj:`h_x`
        bias_pos : torch.Tensor
            Bias parameter vector for positive scores
            shape: :obj:`(num_nodes)`
        bias_neg : torch.Tensor
            Bias parameter vector for negative scores
            shape: :obj:`(num_nodes)`

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            The output scores with shape (2 * num_nodes,), (num_nodes,)
        """
        score_pos = self.affine(h_pos, h_x).squeeze()
        score_neg = self.affine(h_neg, h_x).squeeze()
        if bias_pos is not None:
            score_pos = score_pos + bias_pos
        if bias_neg is not None:
            score_neg = score_neg + bias_neg

        logits = torch.cat((score_pos, score_neg), 0)
        return logits, score_pos


class DenseLayer(torch.nn.Module):
    """
    Description
    -----------
    Dense layer with a linear layer and an activation function.

    Parameters
    ----------
    in_dim : int
        The number of input channels.
    out_dim : int
        The number of output channels.
    act : str, optional
        The activation function type, :obj:`'prelu'` or :obj:`'relu'`.
        Default: :obj:`'prelu'`
    bias : bool, optional
        Whether the linear layer uses a bias.
        Default: :obj:`True`
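
    Examples
    --------
    A minimal usage sketch (sizes are illustrative):

    >>> import torch
    >>> layer = DenseLayer(8, 4)
    >>> layer(torch.randn(2, 8)).shape
    torch.Size([2, 4])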
    """

    def __init__(
        self, in_dim: int, out_dim: int, act: str = "prelu", bias=True
    ):
        super(DenseLayer, self).__init__()
        self.lin = torch.nn.Linear(in_dim, out_dim, bias=bias)
        self.act_type = act.lower()
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin.weight)
        if self.lin.bias is not None:
            torch.nn.init.zeros_(self.lin.bias)
        # Re-create the activation so a PReLU slope is also re-initialized.
        if self.act_type == "prelu":
            self.act = torch.nn.PReLU()
        else:
            self.act = torch.relu

    def forward(self, x):
        x = self.lin(x)
        return self.act(x)


class IndexSelect(torch.nn.Module):
    """
    Description
    -----------
    The index selection layer used by VIPool

    Parameters
    ----------
    pool_ratio : float
        The pooling ratio (the fraction of nodes to keep). For example,
        if :obj:`pool_ratio=0.8`, 80% of the nodes will be preserved.
    hidden_dim : int
        The number of channels in node features.
    act : str, optional
        The activation function type.
        Default: :obj:`'prelu'`
    dist : int, optional
        DO NOT USE THIS PARAMETER
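
    Examples
    --------
    A minimal usage sketch (the graph and sizes below are illustrative):

    >>> import dgl
    >>> import torch
    >>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    >>> select = IndexSelect(pool_ratio=0.5, hidden_dim=8)
    >>> h = torch.randn(4, 8)
    >>> logit, scores, idx, rest, embed = select(g, h, torch.randn(4, 8))
    >>> idx.shape
    torch.Size([2])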
    """

    def __init__(
        self,
        pool_ratio: float,
        hidden_dim: int,
        act: str = "prelu",
        dist: int = 1,
    ):
        super(IndexSelect, self).__init__()
        self.pool_ratio = pool_ratio
        self.dist = dist
        self.dense = DenseLayer(hidden_dim, hidden_dim, act)
        self.discriminator = Discriminator(hidden_dim)
        self.gcn = GraphConvWithDropout(hidden_dim, hidden_dim)

    def forward(
        self,
        graph: DGLGraph,
        h_pos: Tensor,
        h_neg: Tensor,
        bias_pos: Optional[Tensor] = None,
        bias_neg: Optional[Tensor] = None,
    ):
        """
        Description
        -----------
        Perform index selection

        Parameters
        ----------
        graph : dgl.DGLGraph
            Input graph.
        h_pos : torch.Tensor
            The node features of positive samples
            It has the same shape as :obj:`h_x`
        h_neg : torch.Tensor
            The node features of negative samples
            It has the same shape as :obj:`h_x`
        bias_pos : torch.Tensor
            Bias parameter vector for positive scores
            shape: :obj:`(num_nodes)`
        bias_neg : torch.Tensor
            Bias parameter vector for negative scores
            shape: :obj:`(num_nodes)`
        """
        # compute scores
        h_pos = self.dense(h_pos)
        h_neg = self.dense(h_neg)
        embed = self.gcn(graph, h_pos)
        h_center = torch.sigmoid(embed)

        logit, logit_pos = self.discriminator(
            h_center, h_pos, h_neg, bias_pos, bias_neg
        )
        scores = torch.sigmoid(logit_pos)

        # sort scores
        scores, idx = torch.sort(scores, descending=True)

        # select top-k
        num_nodes = graph.num_nodes()
        num_select_nodes = int(self.pool_ratio * num_nodes)
        size_list = [num_select_nodes, num_nodes - num_select_nodes]
        select_scores, _ = torch.split(scores, size_list, dim=0)
        select_idx, non_select_idx = torch.split(idx, size_list, dim=0)

        return logit, select_scores, select_idx, non_select_idx, embed


class GraphPool(torch.nn.Module):
    """
    Description
    -----------
    The pooling module for graph

    Parameters
    ----------
    hidden_dim : int
        The number of channels of node features.
    use_gcn : bool, optional
        Whether to use a GCN in the down-sampling process.
        Default: :obj:`False`
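
    Examples
    --------
    A minimal usage sketch (the graph and sizes below are illustrative):

    >>> import dgl
    >>> import torch
    >>> pool = GraphPool(8)
    >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    >>> pool(g, torch.randn(3, 8), select_idx=torch.tensor([0, 2])).shape
    torch.Size([2, 8])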
    """

    def __init__(self, hidden_dim: int, use_gcn=False):
        super(GraphPool, self).__init__()
        self.use_gcn = use_gcn
        self.down_sample_gcn = (
            GraphConvWithDropout(hidden_dim, hidden_dim) if use_gcn else None
        )

    def forward(
        self,
        graph: DGLGraph,
        feat: Tensor,
        select_idx: Tensor,
        non_select_idx: Optional[Tensor] = None,
        scores: Optional[Tensor] = None,
        pool_graph=False,
    ):
        """
        Description
        -----------
        Perform graph pooling.

        Parameters
        ----------
        graph : dgl.DGLGraph
            The input graph
        feat : torch.Tensor
            The input node features.
        select_idx : torch.Tensor
            The index in the fine graph of each node in the
            coarse graph; this is obtained from previous graph
            pooling layers.
        non_select_idx : torch.Tensor, optional
            The indices not included in the output graph.
            Default: :obj:`None`
        scores : torch.Tensor, optional
            Scores of nodes used for pooling and scaling.
            Default: :obj:`None`
        pool_graph : bool, optional
            Whether to also perform pooling on the graph topology.
            Default: :obj:`False`
        """
        if self.use_gcn:
            feat = self.down_sample_gcn(graph, feat)

        feat = feat[select_idx]
        if scores is not None:
            feat = feat * scores.unsqueeze(-1)

        if pool_graph:
            # Pool the topology by taking the node-induced subgraph and
            # carrying over the batch partition recorded before pooling.
            num_node_batch = graph.batch_num_nodes()
            graph = dgl.node_subgraph(graph, select_idx)
            graph.set_batch_num_nodes(num_node_batch)
            return feat, graph
        else:
            return feat


class GraphUnpool(torch.nn.Module):
    """
    Description
    -----------
    The unpooling module for graph

    Parameters
    ----------
    hidden_dim : int
        The number of channels of node features.
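
    Examples
    --------
    A minimal usage sketch (the graph and sizes below are illustrative):

    >>> import dgl
    >>> import torch
    >>> unpool = GraphUnpool(8)
    >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    >>> unpool(g, torch.randn(2, 8), torch.tensor([0, 2])).shape
    torch.Size([3, 8])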
    """

    def __init__(self, hidden_dim: int):
        super(GraphUnpool, self).__init__()
        self.up_sample_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)

    def forward(self, graph: DGLGraph, feat: Tensor, select_idx: Tensor):
        """
        Description
        -----------
        Perform graph unpooling

        Parameters
        ----------
        graph : dgl.DGLGraph
            The input graph
        feat : torch.Tensor
            The input node feature
        select_idx : torch.Tensor
            The index in fine graph of node from
            coarse graph, this is obtained from
            previous graph pooling layers.
        """
        fine_feat = torch.zeros(
            (graph.num_nodes(), feat.size(-1)), device=feat.device
        )
        # Scatter the coarse features back to their fine-graph positions
        # (zeros elsewhere), then smooth them with a graph convolution.
        fine_feat[select_idx] = feat
        fine_feat = self.up_sample_gcn(graph, fine_feat)
        return fine_feat