import time

import dgl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.linalg import block_diag
from torch.nn import init

from .dgl_layers import DiffPoolBatchedGraphLayer, GraphSage, GraphSageLayer
from .model_utils import batch2tensor
from .tensorized_layers import *


class DiffPool(nn.Module):
    """
    DiffPool Fuse
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        embedding_dim,
        label_dim,
        activation,
        n_layers,
        dropout,
        n_pooling,
        linkpred,
        batch_size,
        aggregator_type,
        assign_dim,
        pool_ratio,
        cat=False,
    ):
        super(DiffPool, self).__init__()
        self.link_pred = linkpred
        self.concat = cat
        self.n_pooling = n_pooling
        self.batch_size = batch_size
        self.link_pred_loss = []
        self.entropy_loss = []

        # list of GNN modules before the first diffpool operation
        self.gc_before_pool = nn.ModuleList()
        self.diffpool_layers = nn.ModuleList()

        # list of list of GNN modules, each list after one diffpool operation
        self.gc_after_pool = nn.ModuleList()
        self.assign_dim = assign_dim
        self.bn = True
        self.num_aggs = 1

        # constructing layers
        # layers before diffpool
        assert n_layers >= 3, "n_layers must be at least 3"
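        # at least: one input->hidden layer, one hidden->hidden layer, and a
        # final hidden->embedding layer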
        self.gc_before_pool.append(
            GraphSageLayer(
                input_dim,
                hidden_dim,
                activation,
                dropout,
                aggregator_type,
                self.bn,
            )
        )
        for _ in range(n_layers - 2):
            self.gc_before_pool.append(
                GraphSageLayer(
                    hidden_dim,
                    hidden_dim,
                    activation,
                    dropout,
                    aggregator_type,
                    self.bn,
                )
            )
        self.gc_before_pool.append(
            GraphSageLayer(
                hidden_dim, embedding_dim, None, dropout, aggregator_type
            )
        )

        assign_dims = [self.assign_dim]
        if self.concat:
            # the diffpool layer receives a pool_embedding_dim node feature
            # tensor and returns pool_embedding_dim node embeddings
            pool_embedding_dim = hidden_dim * (n_layers - 1) + embedding_dim
        else:
            pool_embedding_dim = embedding_dim
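        # e.g. with n_layers=3, hidden_dim=64, embedding_dim=64 and cat=True,
        # pool_embedding_dim = 64 * (3 - 1) + 64 = 192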

        self.first_diffpool_layer = DiffPoolBatchedGraphLayer(
            pool_embedding_dim,
            self.assign_dim,
            hidden_dim,
            activation,
            dropout,
            aggregator_type,
            self.link_pred,
        )
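
        # the first diffpool layer runs on the batched DGL graph; subsequent
        # levels work on the dense per-graph tensors produced by batch2tensor
        # in forward(), hence the Batched* layer variants below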
        gc_after_per_pool = nn.ModuleList()

        for _ in range(n_layers - 1):
            gc_after_per_pool.append(BatchedGraphSAGE(hidden_dim, hidden_dim))
        gc_after_per_pool.append(BatchedGraphSAGE(hidden_dim, embedding_dim))
        self.gc_after_pool.append(gc_after_per_pool)

        self.assign_dim = int(self.assign_dim * pool_ratio)
        # remaining pooling levels: each keeps a pool_ratio fraction of the
        # previous level's clusters
        for _ in range(n_pooling - 1):
            self.diffpool_layers.append(
                BatchedDiffPool(
                    pool_embedding_dim,
                    self.assign_dim,
                    hidden_dim,
                    self.link_pred,
                )
            )
            gc_after_per_pool = nn.ModuleList()
            for _ in range(n_layers - 1):
                gc_after_per_pool.append(BatchedGraphSAGE(hidden_dim, hidden_dim))
            gc_after_per_pool.append(BatchedGraphSAGE(hidden_dim, embedding_dim))
            self.gc_after_pool.append(gc_after_per_pool)
            assign_dims.append(self.assign_dim)
            self.assign_dim = int(self.assign_dim * pool_ratio)

        # predicting layer
        if self.concat:
            self.pred_input_dim = (
                pool_embedding_dim * self.num_aggs * (n_pooling + 1)
            )
        else:
            self.pred_input_dim = embedding_dim * self.num_aggs
        self.pred_layer = nn.Linear(self.pred_input_dim, label_dim)

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform_(
                    m.weight.data, gain=nn.init.calculate_gain("relu")
                )
                if m.bias is not None:
                    m.bias.data = init.constant_(m.bias.data, 0.0)

    def gcn_forward(self, g, h, gc_layers, cat=False):
        """
        Return gc_layer embedding cat.
        """
        block_readout = []
        for gc_layer in gc_layers[:-1]:
            h = gc_layer(g, h)
            block_readout.append(h)
        h = gc_layers[-1](g, h)
        block_readout.append(h)
        if cat:
            block = torch.cat(block_readout, dim=1)  # N x F, F = F1 + F2 + ...
        else:
            block = h
        return block

    def gcn_forward_tensorized(self, h, adj, gc_layers, cat=False):
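        """
        Tensorized counterpart of gcn_forward: h is a dense (batch, n_nodes,
        feat_dim) tensor and adj a dense (batch, n_nodes, n_nodes) tensor.
        """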
        block_readout = []
        for gc_layer in gc_layers:
            h = gc_layer(h, adj)
            block_readout.append(h)
        if cat:
            block = torch.cat(block_readout, dim=2)  # B x N x F, F = F1 + F2 + ...
        else:
            block = h
        return block

    def forward(self, g):
        self.link_pred_loss = []
        self.entropy_loss = []
        h = g.ndata["feat"]
        # node feature for assignment matrix computation is the same as the
        # original node feature
        h_a = h

        out_all = []

        # we use GCN blocks to get an embedding first
        g_embedding = self.gcn_forward(g, h, self.gc_before_pool, self.concat)

        g.ndata["h"] = g_embedding

        readout = dgl.sum_nodes(g, "h")
        out_all.append(readout)
        if self.num_aggs == 2:
            readout = dgl.max_nodes(g, "h")
            out_all.append(readout)

        adj, h = self.first_diffpool_layer(g, g_embedding)
        node_per_pool_graph = int(adj.size()[0] / len(g.batch_num_nodes()))
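        # batch2tensor splits the batched (block-diagonal) pooled adjacency
        # and the stacked node features into dense per-graph tensors of shape
        # (batch, node_per_pool_graph, ...)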
        h, adj = batch2tensor(adj, h, node_per_pool_graph)
        h = self.gcn_forward_tensorized(
            h, adj, self.gc_after_pool[0], self.concat
        )
        readout = torch.sum(h, dim=1)
        out_all.append(readout)
        if self.num_aggs == 2:
            readout, _ = torch.max(h, dim=1)
            out_all.append(readout)

        for i, diffpool_layer in enumerate(self.diffpool_layers):
            h, adj = diffpool_layer(h, adj)
            h = self.gcn_forward_tensorized(
                h, adj, self.gc_after_pool[i + 1], self.concat
            )
            readout = torch.sum(h, dim=1)
            out_all.append(readout)
            if self.num_aggs == 2:
                readout, _ = torch.max(h, dim=1)
                out_all.append(readout)
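        # with concat (or two aggregators) the classifier sees readouts from
        # every pooling level; otherwise only the coarsest level's readout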
        if self.concat or self.num_aggs > 1:
            final_readout = torch.cat(out_all, dim=1)
        else:
            final_readout = readout
        ypred = self.pred_layer(final_readout)
        return ypred

    def loss(self, pred, label):
        """
        Cross-entropy classification loss plus the auxiliary losses
        (e.g. link-prediction and entropy terms) accumulated in each
        diffpool layer's loss_log.
        """
        # softmax + cross-entropy on the predicted logits
        criterion = nn.CrossEntropyLoss()
        loss = criterion(pred, label)
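        # add every auxiliary regularizer recorded by the pooling layers
        # during the forward pass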
        for value in self.first_diffpool_layer.loss_log.values():
            loss += value
        for diffpool_layer in self.diffpool_layers:
            for value in diffpool_layer.loss_log.values():
                loss += value
        return loss
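

# A minimal usage sketch (illustrative, not part of the original module).
# It assumes the relative imports above resolve and that each input graph
# carries node features in g.ndata["feat"]; the constructor arguments below
# (dimensions, aggregator name, ratios) are placeholder example values.
#
#   model = DiffPool(
#       input_dim=89, hidden_dim=64, embedding_dim=64, label_dim=6,
#       activation=F.relu, n_layers=3, dropout=0.0, n_pooling=1,
#       linkpred=True, batch_size=20, aggregator_type="meanpool",
#       assign_dim=200, pool_ratio=0.1, cat=False,
#   )
#   batch_g = dgl.batch(graphs)        # graphs: list of DGLGraph with "feat"
#   logits = model(batch_g)            # (num_graphs, label_dim)
#   loss = model.loss(logits, labels)  # CE + pooling-layer regularizers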