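"""GCN modules: a multi-hop ``GCNLayer`` and a ``GCNNet`` assembled from a
dash-separated architecture string such as ``"1-1-0"``."""
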
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class GCNLayer(nn.Module):
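    """Multi-hop graph convolution layer.

    Each forward pass propagates node features for ``order`` hops (weighted-sum
    aggregation rescaled by a precomputed degree-normalization field), applies a
    separate linear transform to every hop, and combines the per-hop outputs by
    concatenation or summation.
    """
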
    def __init__(
        self,
        in_dim,
        out_dim,
        order=1,
        act=None,
        dropout=0,
        batch_norm=False,
        aggr="concat",
    ):
        super(GCNLayer, self).__init__()
        self.lins = nn.ModuleList()
        self.bias = nn.ParameterList()
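        # one linear transform and bias per hop: index 0 acts on the node's own
        # features, index k on the k-hop aggregated features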
        for _ in range(order + 1):
            self.lins.append(nn.Linear(in_dim, out_dim, bias=False))
            self.bias.append(nn.Parameter(th.zeros(out_dim)))

        self.order = order
        self.act = act
        self.dropout = nn.Dropout(dropout)

        self.batch_norm = batch_norm
        if batch_norm:
            self.offset, self.scale = nn.ParameterList(), nn.ParameterList()
            for _ in range(order + 1):
                self.offset.append(nn.Parameter(th.zeros(out_dim)))
                self.scale.append(nn.Parameter(th.ones(out_dim)))

        self.aggr = aggr
        self.reset_parameters()

    def reset_parameters(self):
        for lin in self.lins:
            nn.init.xavier_normal_(lin.weight)

    def feat_trans(
        self, features, idx
    ):  # linear transformation + activation + batch normalization
        h = self.lins[idx](features) + self.bias[idx]

        if self.act is not None:
            h = self.act(h)

        if self.batch_norm:
            # normalize each node's features over the feature dimension
            # (layer-norm style), with a learnable scale/offset per hop
            mean = h.mean(dim=1).view(h.shape[0], 1)
            var = h.var(dim=1, unbiased=False).view(h.shape[0], 1) + 1e-9
            h = (h - mean) * self.scale[idx] * th.rsqrt(var) + self.offset[idx]

        return h

    def forward(self, graph, features):
        g = graph.local_var()
        h_in = self.dropout(features)
        h_hop = [h_in]

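        # per-node degree-normalization vector precomputed on the graph: use the
        # training-subgraph version when present, otherwise the full-graph one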
        D_norm = (
            g.ndata["train_D_norm"]
            if "train_D_norm" in g.ndata
            else g.ndata["full_D_norm"]
        )
        for _ in range(self.order):  # propagate node features for `order` hops
            g.ndata["h"] = h_hop[-1]
            # default to unit edge weights when the graph carries none
            if "w" not in g.edata:
                g.edata["w"] = th.ones((g.num_edges(),)).to(features.device)
            # weighted-sum aggregation: each neighbor's features are scaled by
            # the edge weight and summed into the destination node
            g.update_all(fn.u_mul_e("h", "w", "m"), fn.sum("m", "h"))
            h = g.ndata.pop("h")
            h = h * D_norm
            h_hop.append(h)

        # per-hop transform, then combine hop outputs: element-wise sum for
        # "mean", feature concatenation for "concat"
        h_part = [self.feat_trans(ft, idx) for idx, ft in enumerate(h_hop)]
        if self.aggr == "mean":
            h_out = h_part[0]
            for i in range(len(h_part) - 1):
                h_out = h_out + h_part[i + 1]
        elif self.aggr == "concat":
            h_out = th.cat(h_part, 1)
        else:
            raise NotImplementedError

        return h_out


class GCNNet(nn.Module):
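    """GCN built from a dash-separated ``arch`` string of per-layer propagation
    orders, followed by L2 normalization and an order-0 output layer."""
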
    def __init__(
        self,
        in_dim,
        hid_dim,
        out_dim,
        arch="1-1-0",
        act=F.relu,
        dropout=0,
        batch_norm=False,
        aggr="concat",
    ):
        super(GCNNet, self).__init__()
        self.gcn = nn.ModuleList()

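        # `arch` is a dash-separated list of per-layer propagation orders,
        # e.g. "1-1-0" builds three GCNLayers with orders 1, 1 and 0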
        orders = list(map(int, arch.split("-")))
        self.gcn.append(
            GCNLayer(
                in_dim=in_dim,
                out_dim=hid_dim,
                order=orders[0],
                act=act,
                dropout=dropout,
                batch_norm=batch_norm,
                aggr=aggr,
            )
        )
        # input width of the next layer: (order + 1) * hid_dim when hop outputs
        # are concatenated, hid_dim otherwise
        pre_out = ((aggr == "concat") * orders[0] + 1) * hid_dim

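        # middle layers: each consumes the (possibly concatenated) output of
        # the previous layer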
        for i in range(1, len(orders) - 1):
            self.gcn.append(
                GCNLayer(
                    in_dim=pre_out,
                    out_dim=hid_dim,
                    order=orders[i],
                    act=act,
                    dropout=dropout,
                    batch_norm=batch_norm,
                    aggr=aggr,
                )
            )
            pre_out = ((aggr == "concat") * orders[i] + 1) * hid_dim

        self.gcn.append(
            GCNLayer(
                in_dim=pre_out,
                out_dim=hid_dim,
                order=orders[-1],
                act=act,
                dropout=dropout,
                batch_norm=batch_norm,
                aggr=aggr,
            )
        )
        pre_out = ((aggr == "concat") * orders[-1] + 1) * hid_dim

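        # output projection: order 0 (no propagation), no activation, no batch norm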
        self.out_layer = GCNLayer(
            in_dim=pre_out,
            out_dim=out_dim,
            order=0,
            act=None,
            dropout=dropout,
            batch_norm=False,
            aggr=aggr,
        )

    def forward(self, graph):
        h = graph.ndata["feat"]

        for layer in self.gcn:
            h = layer(graph, h)

        h = F.normalize(h, p=2, dim=1)  # L2-normalize before the output layer
        h = self.out_layer(graph, h)

        return h
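

# Minimal usage sketch (assumptions: ``g`` is a DGLGraph whose ndata already
# holds node features under "feat" and a per-node degree-normalization column
# under "full_D_norm" (or "train_D_norm" when training on sampled subgraphs),
# and ``num_classes`` is the number of target classes):
#
#     model = GCNNet(
#         in_dim=g.ndata["feat"].shape[1], hid_dim=64, out_dim=num_classes
#     )
#     logits = model(g)  # shape: (num_nodes, out_dim)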