import torch
import torch.nn
import torch.nn.functional as F

from layers import ConvPoolReadout


class HGPSLModel(torch.nn.Module):
    r"""

    Description
    -----------
    The graph classification model using HGP-SL pooling.

    Parameters
    ----------
    in_feat : int
        The number of input node feature channels.
    out_feat : int
        The number of output channels (i.e., the number of classes).
    hid_feat : int
        The number of hidden feature channels.
    dropout : float, optional
        The dropout rate. Default: 0.0
    pool_ratio : float, optional
        The pooling ratio for each pooling layer. Default: 0.5
    conv_layers : int, optional
        The number of graph convolution and pooling layers. Default: 3
    sample : bool, optional
        Whether to use the k-hop union graph to improve efficiency.
        Currently only the full graph is supported. Default: :obj:`False`
    sparse : bool, optional
        Whether to use edge sparsemax instead of edge softmax.
        Default: :obj:`True`
    sl : bool, optional
        Whether to use the structure learning module. Default: :obj:`True`
    lamb : float, optional
        The lambda parameter that weights the raw adjacency matrix, as
        described in the HGP-SL paper. Default: 1.0
    """

    def __init__(
        self,
        in_feat: int,
        out_feat: int,
        hid_feat: int,
        dropout: float = 0.0,
        pool_ratio: float = 0.5,
        conv_layers: int = 3,
        sample: bool = False,
        sparse: bool = True,
        sl: bool = True,
        lamb: float = 1.0,
    ):
        super(HGPSLModel, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.hid_feat = hid_feat
        self.dropout = dropout
        self.num_layers = conv_layers
        self.pool_ratio = pool_ratio

        # Stack conv + pool blocks; every hidden layer is hid_feat wide.
        convpools = []
        for i in range(conv_layers):
            c_in = in_feat if i == 0 else hid_feat
            c_out = hid_feat
            # The final block only convolves; it performs no pooling.
            use_pool = i != conv_layers - 1
            convpools.append(
                ConvPoolReadout(
                    c_in,
                    c_out,
                    pool_ratio=pool_ratio,
                    sample=sample,
                    sparse=sparse,
                    sl=sl,
                    lamb=lamb,
                    pool=use_pool,
                )
            )
        self.convpool_layers = torch.nn.ModuleList(convpools)

        # Each block's readout is 2 * hid_feat wide: a concatenation of
        # average- and max-pooled node features.
        self.lin1 = torch.nn.Linear(hid_feat * 2, hid_feat)
        self.lin2 = torch.nn.Linear(hid_feat, hid_feat // 2)
        self.lin3 = torch.nn.Linear(hid_feat // 2, self.out_feat)

    def forward(self, graph, n_feat):
        final_readout = None
        e_feat = None

        for i in range(self.num_layers):
            graph, n_feat, e_feat, readout = self.convpool_layers[i](
                graph, n_feat, e_feat
            )
            # Sum the graph-level readouts across all blocks.
            if final_readout is None:
                final_readout = readout
            else:
                final_readout = final_readout + readout

        # Three-layer MLP classifier on the aggregated readout.
        n_feat = F.relu(self.lin1(final_readout))
        n_feat = F.dropout(n_feat, p=self.dropout, training=self.training)
        n_feat = F.relu(self.lin2(n_feat))
        n_feat = F.dropout(n_feat, p=self.dropout, training=self.training)
        n_feat = self.lin3(n_feat)

        return F.log_softmax(n_feat, dim=-1)
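

if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustrative addition, not part of the
    # training script). It assumes `ConvPoolReadout` from layers.py accepts a
    # DGLGraph with node features and optional edge features, as unpacked in
    # `forward` above; the sizes below are arbitrary.
    import dgl

    num_nodes, num_edges = 20, 60
    feat_size, num_classes = 16, 4

    # Self-loops avoid zero-in-degree issues during message passing.
    g = dgl.add_self_loop(dgl.rand_graph(num_nodes, num_edges))
    feat = torch.randn(num_nodes, feat_size)

    model = HGPSLModel(in_feat=feat_size, out_feat=num_classes, hid_feat=32)
    log_probs = model(g, feat)
    print(log_probs.shape)  # expected: torch.Size([1, 4])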