# hetero_rgcn.py — R-GCN node classification on the ogbn-mag heterogeneous graph.
import argparse
import itertools

from tqdm import tqdm

import torch as th
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.nn as dglnn
from dgl import AddReverse, Compose, ToSimple
from dgl.nn import HeteroEmbedding

from ogb.nodeproppred import DglNodePropPredDataset, Evaluator


def prepare_data(args):
    """Load ogbn-mag, canonicalize the graph and build the training loader.

    Returns
    -------
    tuple
        (graph, paper labels, num_classes, OGB split index, Logger,
        neighbor-sampling DataLoader over the training papers).
    """
    dataset = DglNodePropPredDataset(name="ogbn-mag")
    split_idx = dataset.get_idx_split()
    # dataset[0] -> (heterograph, {ntype: labels}); only 'paper' nodes are labeled.
    g, labels = dataset[0]
    labels = labels['paper'].flatten()

    # Deduplicate parallel edges, then add reverse edges so messages can
    # flow in both directions during convolution.
    canonicalize = Compose([ToSimple(), AddReverse()])
    g = canonicalize(g)

    print("Loaded graph: {}".format(g))

    logger = Logger(args.runs)

    # Neighbor fanouts for the two GNN layers: 25 for the first hop,
    # 20 for the second.
    fanouts = [25, 20]
    sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
    train_loader = dgl.dataloading.DataLoader(
        g, split_idx['train'], sampler,
        batch_size=1024, shuffle=True, num_workers=0)

    return g, labels, dataset.num_classes, split_idx, logger, train_loader
def extract_embed(node_embed, input_nodes):
    """Look up learnable embeddings for every featureless node type.

    'paper' nodes carry raw input features, so they are excluded from
    the embedding lookup.

    Parameters
    ----------
    node_embed : callable
        Maps a dict {ntype: node ids} to a dict {ntype: embeddings}.
    input_nodes : dict[str, torch.Tensor]
        Input node ids per node type, possibly including 'paper'.
    """
    featureless = {
        ntype: nids
        for ntype, nids in input_nodes.items() if ntype != 'paper'
    }
    return node_embed(featureless)
def rel_graph_embed(graph, embed_size):
    """Build a HeteroEmbedding for every node type except 'paper'.

    'paper' nodes come with raw input features; all other (featureless)
    node types get a learnable embedding table of width `embed_size`.
    """
    counts = {
        ntype: graph.num_nodes(ntype)
        for ntype in graph.ntypes
        if ntype != 'paper'
    }
    return HeteroEmbedding(counts, embed_size)
class RelGraphConvLayer(nn.Module):
    """One relational graph-convolution layer over a heterogeneous graph.

    Per relation: a ``GraphConv`` with right normalization whose weight is
    supplied externally at forward time. Per destination node type: a
    learnable self-loop projection (with bias). The per-type output is the
    sum of the relation messages and the self-loop term, followed by an
    optional activation and dropout.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    ntypes : list[str]
        Node types of the graph; each gets its own self-loop Linear.
    rel_names : list[str]
        Relation (edge type) names; each gets its own weight matrix.
    activation : callable, optional
        Applied to the summed output (e.g. ``F.relu``).
    dropout : float
        Dropout probability applied last.
    """

    def __init__(self,
                 in_feat,
                 out_feat,
                 ntypes,
                 rel_names,
                 activation=None,
                 dropout=0.0):
        super(RelGraphConvLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.ntypes = ntypes
        self.rel_names = rel_names
        self.activation = activation

        # One GraphConv per relation. Weights are injected per-call via
        # mod_kwargs in forward(), hence weight=False/bias=False here.
        convs = {}
        for rel in rel_names:
            convs[rel] = dglnn.GraphConv(
                in_feat, out_feat, norm='right', weight=False, bias=False)
        self.conv = dglnn.HeteroGraphConv(convs)

        # Relation-specific projection matrices (no bias).
        self.weight = nn.ModuleDict({
            rel_name: nn.Linear(in_feat, out_feat, bias=False)
            for rel_name in self.rel_names
        })

        # Self-loop projection per destination node type (with bias).
        self.loop_weights = nn.ModuleDict({
            ntype: nn.Linear(in_feat, out_feat, bias=True)
            for ntype in self.ntypes
        })

        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize all relation and self-loop projections."""
        for module in itertools.chain(self.weight.values(),
                                      self.loop_weights.values()):
            module.reset_parameters()

    def forward(self, g, inputs):
        """Run one message-passing step.

        Parameters
        ----------
        g : DGLHeteroGraph
            Input graph (or sampled block).
        inputs : dict[str, torch.Tensor]
            Node feature for each node type.

        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each node type.
        """
        g = g.local_var()

        # Hand each relation its projection; GraphConv expects the
        # transposed Linear weight.
        wdict = {}
        for rel_name in self.rel_names:
            wdict[rel_name] = {'weight': self.weight[rel_name].weight.T}

        # By DGL block convention, the destination nodes are the first
        # number_of_dst_nodes entries of each input tensor.
        inputs_dst = {
            ntype: feat[:g.number_of_dst_nodes(ntype)]
            for ntype, feat in inputs.items()
        }

        hs = self.conv(g, inputs, mod_kwargs=wdict)

        outputs = {}
        for ntype, h in hs.items():
            h = h + self.loop_weights[ntype](inputs_dst[ntype])
            if self.activation:
                h = self.activation(h)
            outputs[ntype] = self.dropout(h)
        return outputs
class EntityClassify(nn.Module):
    """Two-layer R-GCN entity classifier: in_dim -> 64 -> out_dim.

    Parameters
    ----------
    g : DGLHeteroGraph
        Graph whose node and edge types define the layer structure.
    in_dim : int
        Input feature size.
    out_dim : int
        Number of output classes.
    """

    def __init__(self, g, in_dim, out_dim):
        super(EntityClassify, self).__init__()
        self.in_dim = in_dim
        self.h_dim = 64
        self.out_dim = out_dim
        # Deterministic, duplicate-free ordering of relation names.
        self.rel_names = sorted(set(g.etypes))
        self.dropout = 0.5

        self.layers = nn.ModuleList([
            # input -> hidden, with ReLU and dropout
            RelGraphConvLayer(
                self.in_dim, self.h_dim, g.ntypes, self.rel_names,
                activation=F.relu, dropout=self.dropout),
            # hidden -> output, raw logits (no activation, no dropout)
            RelGraphConvLayer(
                self.h_dim, self.out_dim, g.ntypes, self.rel_names,
                activation=None),
        ])

    def reset_parameters(self):
        """Re-initialize every layer's parameters."""
        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, h, blocks):
        """Apply each layer to its corresponding sampled block."""
        for block, layer in zip(blocks, self.layers):
            h = layer(block, h)
        return h
class Logger(object):
    r"""
    This class was taken directly from the PyG implementation and can be found
    here: https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/mag/logger.py

    This was done to ensure that performance was measured in precisely the same way
    """

    def __init__(self, runs):
        # One result list per run; entries are (train, valid, test) accuracies.
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        """Record a (train, valid, test) accuracy triple for `run`."""
        assert len(result) == 3
        assert 0 <= run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None):
        """Print accuracy statistics for a single run, or across all runs."""
        if run is not None:
            result = 100 * th.tensor(self.results[run])
            argmax = result[:, 1].argmax().item()
            print(f'Run {run + 1:02d}:')
            print(f'Highest Train: {result[:, 0].max():.2f}')
            print(f'Highest Valid: {result[:, 1].max():.2f}')
            print(f'  Final Train: {result[argmax, 0]:.2f}')
            print(f'   Final Test: {result[argmax, 2]:.2f}')
        else:
            result = 100 * th.tensor(self.results)

            # Per run: best train/valid, plus the train/test accuracy at
            # the epoch where validation peaked.
            best_results = []
            for r in result:
                best_epoch = r[:, 1].argmax()
                best_results.append((
                    r[:, 0].max().item(),
                    r[:, 1].max().item(),
                    r[best_epoch, 0].item(),
                    r[best_epoch, 2].item(),
                ))

            best_result = th.tensor(best_results)

            print(f'All runs:')
            for label, col in (('Highest Train', 0),
                               ('Highest Valid', 1),
                               ('  Final Train', 2),
                               ('   Final Test', 3)):
                r = best_result[:, col]
                print(f'{label}: {r.mean():.2f} ± {r.std():.2f}')
def train(g, model, node_embed, optimizer, train_loader, split_idx,
          labels, logger, device, run):
    """Train `model` on the 'paper' nodes and log per-epoch accuracies.

    Parameters
    ----------
    g : DGLHeteroGraph
        Full graph; raw 'paper' features are read from g.ndata['feat'].
    model : nn.Module
        Classifier taking (feature dict, blocks) and returning logits per type.
    node_embed : callable
        Embedding lookup for the featureless node types (see extract_embed).
    optimizer : torch.optim.Optimizer
        Optimizer over model and embedding parameters.
    train_loader : dgl.dataloading.DataLoader
        Neighbor-sampling loader over the training papers.
    split_idx : dict
        OGB split index; split_idx['train']['paper'] holds the training ids.
    labels : torch.Tensor
        Flattened labels for all paper nodes.
    logger : Logger
        Collects one (train, valid, test) accuracy triple per epoch.
    device : str
        Device that blocks, features and labels are moved to.
    run : int
        Index of the current run (for logging only).

    Returns
    -------
    Logger
        The same logger with this run's results appended.
    """
    print("start training...")
    category = 'paper'

    # NOTE(review): epoch count is hard-coded to 3.
    for epoch in range(3):
        num_train = split_idx['train'][category].shape[0]
        pbar = tqdm(total=num_train)
        pbar.set_description(f'Epoch {epoch:02d}')
        model.train()

        total_loss = 0

        for input_nodes, seeds, blocks in train_loader:
            blocks = [blk.to(device) for blk in blocks]
            seeds = seeds[category]     # we only predict the nodes with type "category"
            batch_size = seeds.shape[0]

            # Learnable embeddings for non-paper input nodes.
            emb = extract_embed(node_embed, input_nodes)
            # Add the batch's raw "paper" features
            emb.update({'paper': g.ndata['feat']['paper'][input_nodes['paper']]})

            emb = {k : e.to(device) for k, e in emb.items()}
            lbl = labels[seeds].to(device)

            optimizer.zero_grad()
            logits = model(emb, blocks)[category]

            # Cross-entropy via explicit log-softmax + NLL.
            y_hat = logits.log_softmax(dim=-1)
            loss = F.nll_loss(y_hat, lbl)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size for a correct epoch mean.
            total_loss += loss.item() * batch_size
            pbar.update(batch_size)

        pbar.close()

        # Mean training loss over the epoch.
        loss = total_loss / num_train

        # Full evaluation (train/valid/test accuracy) after every epoch.
        result = test(g, model, node_embed, labels, device, split_idx)
        logger.add_result(run, result)
        train_acc, valid_acc, test_acc = result
        print(f'Run: {run + 1:02d}, '
              f'Epoch: {epoch +1 :02d}, '
              f'Loss: {loss:.4f}, '
              f'Train: {100 * train_acc:.2f}%, '
              f'Valid: {100 * valid_acc:.2f}%, '
              f'Test: {100 * test_acc:.2f}%')

    return logger
@th.no_grad()
def test(g, model, node_embed, y_true, device, split_idx):
    """Evaluate the model on all paper nodes with full-neighbor inference.

    Parameters
    ----------
    g : DGLHeteroGraph
        Full graph; raw 'paper' features are read from g.ndata['feat'].
    model : nn.Module
        Trained classifier.
    node_embed : callable
        Embedding lookup for the featureless node types.
    y_true : torch.Tensor
        Ground-truth labels for all paper nodes (1-D).
    device : str
        Device the blocks and features are moved to.
    split_idx : dict
        OGB split index used to slice out train/valid/test predictions.

    Returns
    -------
    tuple[float, float, float]
        (train_acc, valid_acc, test_acc) as computed by the OGB evaluator.
    """
    model.eval()
    category = 'paper'
    evaluator = Evaluator(name='ogbn-mag')

    # 2 GNN layers
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    # Predict every paper node, in order, so predictions align with y_true.
    loader = dgl.dataloading.DataLoader(
        g, {'paper': th.arange(g.num_nodes('paper'))}, sampler,
        batch_size=16384, shuffle=False, num_workers=0)

    pbar = tqdm(total=y_true.size(0))
    pbar.set_description(f'Inference')

    y_hats = list()

    for input_nodes, seeds, blocks in loader:
        blocks = [blk.to(device) for blk in blocks]
        seeds = seeds[category]     # we only predict the nodes with type "category"
        batch_size = seeds.shape[0]

        # Learnable embeddings for non-paper input nodes.
        emb = extract_embed(node_embed, input_nodes)
        # Get the batch's raw "paper" features
        emb.update({'paper': g.ndata['feat']['paper'][input_nodes['paper']]})
        emb = {k : e.to(device) for k, e in emb.items()}

        logits = model(emb, blocks)[category]
        # Keep a (batch, 1) column of predicted class ids, as OGB expects.
        y_hat = logits.log_softmax(dim=-1).argmax(dim=1, keepdims=True)
        y_hats.append(y_hat.cpu())

        pbar.update(batch_size)

    pbar.close()

    y_pred = th.cat(y_hats, dim=0)
    # OGB evaluator expects (N, 1) label tensors.
    y_true = th.unsqueeze(y_true, 1)

    train_acc = evaluator.eval({
        'y_true': y_true[split_idx['train']['paper']],
        'y_pred': y_pred[split_idx['train']['paper']],
    })['acc']
    valid_acc = evaluator.eval({
        'y_true': y_true[split_idx['valid']['paper']],
        'y_pred': y_pred[split_idx['valid']['paper']],
    })['acc']
    test_acc = evaluator.eval({
        'y_true': y_true[split_idx['test']['paper']],
        'y_pred': y_pred[split_idx['test']['paper']],
    })['acc']

    return train_acc, valid_acc, test_acc
def main(args):
    """Run `args.runs` independent train/evaluate cycles and print stats."""
    device = f'cuda:0' if th.cuda.is_available() else 'cpu'

    g, labels, num_classes, split_idx, logger, train_loader = prepare_data(args)

    # 128-d learnable embeddings for featureless node types; 128-d model input.
    embed_layer = rel_graph_embed(g, 128)
    model = EntityClassify(g, 128, num_classes).to(device)

    print(f"Number of embedding parameters: {sum(p.numel() for p in embed_layer.parameters())}")
    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters())}")

    for run in range(args.runs):
        # Fresh parameters for every run.
        embed_layer.reset_parameters()
        model.reset_parameters()

        # Jointly optimize the model weights and the node embeddings.
        trainable = itertools.chain(model.parameters(), embed_layer.parameters())
        optimizer = th.optim.Adam(trainable, lr=0.01)

        logger = train(g, model, embed_layer, optimizer, train_loader,
                       split_idx, labels, logger, device, run)
        logger.print_statistics(run)

    print("Final performance: ")
    logger.print_statistics()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RGCN')
    # Number of independent training runs to aggregate statistics over.
    parser.add_argument('--runs', type=int, default=10)

    args = parser.parse_args()

    main(args)