# hetero_rgcn.py — mini-batch R-GCN example on the ogbn-mag dataset.
import argparse
import itertools

import torch as th
import torch.nn as nn
import torch.nn.functional as F
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from tqdm import tqdm

import dgl
import dgl.nn as dglnn
from dgl import AddReverse, Compose, ToSimple
from dgl.nn import HeteroEmbedding

def prepare_data(args):
    """Load and preprocess ogbn-mag and build the training dataloader.

    Returns
    -------
    tuple
        (graph, paper labels, number of classes, split-index dict,
        Logger instance, training DataLoader).
    """
    dataset = DglNodePropPredDataset(name="ogbn-mag")
    split_idx = dataset.get_idx_split()
    # graph: dgl graph object, label: torch tensor of shape (num_nodes, num_tasks)
    g, labels = dataset[0]
    labels = labels["paper"].flatten()

    # Collapse duplicate edges, then add reverse edges so that messages
    # can flow in both directions of every relation.
    g = Compose([ToSimple(), AddReverse()])(g)

    print("Loaded graph: {}".format(g))

    logger = Logger(args.runs)

    # Neighbor sampler for the two GNN layers (fan-outs 25 and 20).
    sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 20])
    train_loader = dgl.dataloading.DataLoader(
        g,
        split_idx["train"],
        sampler,
        batch_size=1024,
        shuffle=True,
        num_workers=0,
    )

    return g, labels, dataset.num_classes, split_idx, logger, train_loader
def extract_embed(node_embed, input_nodes):
    """Look up learnable embeddings for the featureless node types.

    "paper" nodes carry raw input features and are therefore skipped here.
    """
    featureless = {
        ntype: nids
        for ntype, nids in input_nodes.items()
        if ntype != "paper"
    }
    return node_embed(featureless)
def rel_graph_embed(graph, embed_size):
    """Build a HeteroEmbedding covering every node type except "paper".

    "paper" nodes come with raw input features, so no learned embedding
    is allocated for them.
    """
    sizes = {
        ntype: graph.num_nodes(ntype)
        for ntype in graph.ntypes
        if ntype != "paper"
    }
    return HeteroEmbedding(sizes, embed_size)
class RelGraphConvLayer(nn.Module):
    """One relational graph convolution layer for heterogeneous graphs.

    Messages are aggregated per relation by parameter-free GraphConv
    modules; the per-relation projection matrices and per-node-type
    self-loop weights are owned by this layer and injected at forward time.
    """

    def __init__(
        self, in_feat, out_feat, ntypes, rel_names, activation=None, dropout=0.0
    ):
        super(RelGraphConvLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.ntypes = ntypes
        self.rel_names = rel_names
        self.activation = activation

        # The convolutions themselves hold no weights (weight=False);
        # projection matrices are passed in via mod_kwargs in forward().
        self.conv = dglnn.HeteroGraphConv(
            {
                rel: dglnn.GraphConv(
                    in_feat, out_feat, norm="right", weight=False, bias=False
                )
                for rel in rel_names
            }
        )

        # One linear projection per relation type.
        self.weight = nn.ModuleDict(
            {
                rel_name: nn.Linear(in_feat, out_feat, bias=False)
                for rel_name in self.rel_names
            }
        )

        # weight for self loop
        self.loop_weights = nn.ModuleDict(
            {
                ntype: nn.Linear(in_feat, out_feat, bias=True)
                for ntype in self.ntypes
            }
        )

        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every projection and self-loop linear module."""
        for module in itertools.chain(
            self.weight.values(), self.loop_weights.values()
        ):
            module.reset_parameters()

    def forward(self, g, inputs):
        """
        Parameters
        ----------
        g : DGLHeteroGraph
            Input graph.
        inputs : dict[str, torch.Tensor]
            Node feature for each node type.

        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each node type.
        """
        g = g.local_var()
        # Hand each relation its projection matrix; transposed to the
        # (in, out) layout that GraphConv expects for an external weight.
        wdict = {
            rel: {"weight": self.weight[rel].weight.T}
            for rel in self.rel_names
        }

        # In a sampled block the destination nodes are a prefix of the
        # input (source) nodes, so slicing yields the dst features.
        inputs_dst = {
            ntype: feat[: g.number_of_dst_nodes(ntype)]
            for ntype, feat in inputs.items()
        }

        hs = self.conv(g, inputs, mod_kwargs=wdict)

        def _combine(ntype, h):
            # Add the self-loop contribution, then activation and dropout.
            h = h + self.loop_weights[ntype](inputs_dst[ntype])
            if self.activation:
                h = self.activation(h)
            return self.dropout(h)

        return {ntype: _combine(ntype, h) for ntype, h in hs.items()}
class EntityClassify(nn.Module):
    """Two-layer R-GCN node classifier for heterogeneous graphs."""

    def __init__(self, g, in_dim, out_dim):
        super(EntityClassify, self).__init__()
        self.in_dim = in_dim
        self.h_dim = 64
        self.out_dim = out_dim
        # Sorted, deduplicated relation names keep parameter ordering
        # deterministic across runs.
        self.rel_names = sorted(set(g.etypes))
        self.dropout = 0.5

        # Input-to-hidden layer followed by hidden-to-output layer.
        self.layers = nn.ModuleList(
            [
                RelGraphConvLayer(
                    self.in_dim,
                    self.h_dim,
                    g.ntypes,
                    self.rel_names,
                    activation=F.relu,
                    dropout=self.dropout,
                ),
                RelGraphConvLayer(
                    self.h_dim,
                    self.out_dim,
                    g.ntypes,
                    self.rel_names,
                    activation=None,
                ),
            ]
        )

    def reset_parameters(self):
        """Re-initialize every convolution layer."""
        for conv in self.layers:
            conv.reset_parameters()

    def forward(self, h, blocks):
        """Apply each layer to its corresponding sampled block."""
        for conv, blk in zip(self.layers, blocks):
            h = conv(blk, h)
        return h
class Logger(object):
    r"""
    This class was taken directly from the PyG implementation and can be found
    here: https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/mag/logger.py

    This was done to ensure that performance was measured in precisely the same way
    """

    def __init__(self, runs):
        # One list of (train, valid, test) accuracy triples per run.
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        """Record one (train, valid, test) accuracy triple for a run."""
        assert len(result) == 3
        assert 0 <= run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None):
        """Print per-run statistics, or aggregates over all runs."""
        if run is not None:
            result = 100 * th.tensor(self.results[run])
            # Epoch with the best validation accuracy decides "final".
            argmax = result[:, 1].argmax().item()
            print(f"Run {run + 1:02d}:")
            print(f"Highest Train: {result[:, 0].max():.2f}")
            print(f"Highest Valid: {result[:, 1].max():.2f}")
            print(f"  Final Train: {result[argmax, 0]:.2f}")
            print(f"   Final Test: {result[argmax, 2]:.2f}")
        else:
            result = 100 * th.tensor(self.results)

            best_results = []
            for r in result:
                best_epoch = r[:, 1].argmax()
                best_results.append(
                    (
                        r[:, 0].max().item(),
                        r[:, 1].max().item(),
                        r[best_epoch, 0].item(),
                        r[best_epoch, 2].item(),
                    )
                )

            best_result = th.tensor(best_results)

            print(f"All runs:")
            for label, col in (
                ("Highest Train", 0),
                ("Highest Valid", 1),
                ("  Final Train", 2),
                ("   Final Test", 3),
            ):
                vals = best_result[:, col]
                print(f"{label}: {vals.mean():.2f} ± {vals.std():.2f}")
def train(
    g,
    model,
    node_embed,
    optimizer,
    train_loader,
    split_idx,
    labels,
    logger,
    device,
    run,
    num_epochs=3,
):
    """Train the model for one run and log per-epoch evaluation results.

    Parameters
    ----------
    g : DGLHeteroGraph
        Full graph; supplies the raw "paper" features.
    model : nn.Module
        The entity classification model.
    node_embed : callable
        Embedding lookup for featureless node types (see extract_embed).
    optimizer : torch.optim.Optimizer
        Optimizer over model and embedding parameters.
    train_loader : dgl.dataloading.DataLoader
        Mini-batch loader over the training "paper" nodes.
    split_idx : dict
        OGB split indices per node type.
    labels : torch.Tensor
        Flattened ground-truth labels for "paper" nodes.
    logger : Logger
        Accumulates (train, valid, test) accuracy per epoch.
    device : str or torch.device
        Device used for model computation.
    run : int
        Index of the current run (for logging only).
    num_epochs : int, optional
        Number of training epochs. Defaults to 3, matching the previously
        hard-coded value, so existing callers are unaffected.

    Returns
    -------
    Logger
        The same logger, updated with this run's results.
    """
    print("start training...")
    category = "paper"

    for epoch in range(num_epochs):
        num_train = split_idx["train"][category].shape[0]
        pbar = tqdm(total=num_train)
        pbar.set_description(f"Epoch {epoch:02d}")
        model.train()

        total_loss = 0

        for input_nodes, seeds, blocks in train_loader:
            blocks = [blk.to(device) for blk in blocks]
            # We only predict the nodes with type `category` ("paper").
            seeds = seeds[category]
            batch_size = seeds.shape[0]

            # Learned embeddings for featureless node types ...
            emb = extract_embed(node_embed, input_nodes)
            # ... plus the batch's raw "paper" features.
            emb["paper"] = g.ndata["feat"]["paper"][input_nodes["paper"]]
            emb = {ntype: feat.to(device) for ntype, feat in emb.items()}
            lbl = labels[seeds].to(device)

            optimizer.zero_grad()
            logits = model(emb, blocks)[category]

            y_hat = logits.log_softmax(dim=-1)
            loss = F.nll_loss(y_hat, lbl)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * batch_size
            pbar.update(batch_size)

        pbar.close()
        # Average loss over all training seeds for this epoch.
        loss = total_loss / num_train

        result = test(g, model, node_embed, labels, device, split_idx)
        logger.add_result(run, result)
        train_acc, valid_acc, test_acc = result
        print(
            f"Run: {run + 1:02d}, "
            f"Epoch: {epoch + 1:02d}, "
            f"Loss: {loss:.4f}, "
            f"Train: {100 * train_acc:.2f}%, "
            f"Valid: {100 * valid_acc:.2f}%, "
            f"Test: {100 * test_acc:.2f}%"
        )

    return logger
@th.no_grad()
def test(g, model, node_embed, y_true, device, split_idx):
    """Run full-neighborhood inference over all "paper" nodes and return
    (train, valid, test) accuracies as computed by the OGB evaluator."""
    model.eval()
    category = "paper"
    evaluator = Evaluator(name="ogbn-mag")

    # 2 GNN layers
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    loader = dgl.dataloading.DataLoader(
        g,
        {"paper": th.arange(g.num_nodes("paper"))},
        sampler,
        batch_size=16384,
        shuffle=False,
        num_workers=0,
    )

    pbar = tqdm(total=y_true.size(0))
    pbar.set_description(f"Inference")

    y_hats = []
    for input_nodes, seeds, blocks in loader:
        blocks = [blk.to(device) for blk in blocks]
        # We only predict the nodes with type `category` ("paper").
        seeds = seeds[category]
        batch_size = seeds.shape[0]

        # Learned embeddings plus the batch's raw "paper" features.
        emb = extract_embed(node_embed, input_nodes)
        emb["paper"] = g.ndata["feat"]["paper"][input_nodes["paper"]]
        emb = {ntype: feat.to(device) for ntype, feat in emb.items()}

        logits = model(emb, blocks)[category]
        y_hats.append(
            logits.log_softmax(dim=-1).argmax(dim=1, keepdims=True).cpu()
        )

        pbar.update(batch_size)

    pbar.close()

    y_pred = th.cat(y_hats, dim=0)
    y_true = th.unsqueeze(y_true, 1)

    def _accuracy(split):
        # Evaluate one OGB split against the full prediction tensor.
        idx = split_idx[split]["paper"]
        return evaluator.eval(
            {"y_true": y_true[idx], "y_pred": y_pred[idx]}
        )["acc"]

    return _accuracy("train"), _accuracy("valid"), _accuracy("test")
def main(args):
    """Entry point: prepare data, then train and evaluate for each run."""
    device = f"cuda:0" if th.cuda.is_available() else "cpu"

    g, labels, num_classes, split_idx, logger, train_loader = prepare_data(args)

    embed_layer = rel_graph_embed(g, 128)
    model = EntityClassify(g, 128, num_classes).to(device)

    print(
        f"Number of embedding parameters: {sum(p.numel() for p in embed_layer.parameters())}"
    )
    print(
        f"Number of model parameters: {sum(p.numel() for p in model.parameters())}"
    )

    for run in range(args.runs):
        # Fresh parameters for every run.
        embed_layer.reset_parameters()
        model.reset_parameters()

        # Optimize the model and the entity embeddings jointly.
        optimizer = th.optim.Adam(
            itertools.chain(model.parameters(), embed_layer.parameters()),
            lr=0.01,
        )

        logger = train(
            g,
            model,
            embed_layer,
            optimizer,
            train_loader,
            split_idx,
            labels,
            logger,
            device,
            run,
        )
        logger.print_statistics(run)

    print("Final performance: ")
    logger.print_statistics()
if __name__ == "__main__":
    # Command-line entry point: only --runs is configurable.
    arg_parser = argparse.ArgumentParser(description="RGCN")
    arg_parser.add_argument("--runs", type=int, default=10)

    main(arg_parser.parse_args())