hetero_rgcn.py 12.9 KB
Newer Older
1
2
3
4
5
6
7
import argparse
import itertools

import torch as th
import torch.nn as nn
import torch.nn.functional as F
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
8
9
10
11
12
13
14
from tqdm import tqdm

import dgl
import dgl.nn as dglnn
from dgl import AddReverse, Compose, ToSimple
from dgl.nn import HeteroEmbedding

15
16
import psutil
import sys
17

18
19
20
v_t = dgl.__version__

def prepare_data(args, device):
    """Load ogbn-mag, preprocess the graph, and build the training loader.

    Returns the transformed graph, flattened "paper" labels, the number of
    classes, the OGB split indices, a fresh Logger, and a neighbor-sampling
    train DataLoader.
    """
    dataset = DglNodePropPredDataset(name="ogbn-mag")
    split_idx = dataset.get_idx_split()
    # dataset[0] yields (graph, label dict); only "paper" nodes carry labels.
    graph, label_dict = dataset[0]
    paper_labels = label_dict["paper"].flatten()

    # Remove duplicate edges, then add reverse edges so messages can flow
    # in both directions across each relation.
    graph = Compose([ToSimple(), AddReverse()])(graph)
    print(f"Loaded graph: {graph}")

    run_logger = Logger(args.runs)

    # Neighbor sampler: 25 neighbors at the first hop, 20 at the second.
    neighbor_sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 20])
    train_loader = dgl.dataloading.DataLoader(
        graph,
        split_idx["train"],
        neighbor_sampler,
        batch_size=1024,
        shuffle=True,
        num_workers=args.num_workers,
        device=device,
    )

    return (
        graph,
        paper_labels,
        dataset.num_classes,
        split_idx,
        run_logger,
        train_loader,
    )
48

49

YJ-Zhao's avatar
YJ-Zhao committed
50
def extract_embed(node_embed, input_nodes):
    """Look up learnable embeddings for every featureless node type.

    "paper" entries are skipped: those nodes carry raw input features, so
    they need no learned embedding.
    """
    featureless = {
        ntype: ids for ntype, ids in input_nodes.items() if ntype != "paper"
    }
    return node_embed(featureless)
55

56

YJ-Zhao's avatar
YJ-Zhao committed
57
58
59
def rel_graph_embed(graph, embed_size):
    """Build learnable node embeddings for every node type except "paper".

    "paper" nodes already come with raw input features, so only the
    remaining (featureless) node types receive an embedding table.
    """
    counts = {
        ntype: graph.num_nodes(ntype)
        for ntype in graph.ntypes
        if ntype != "paper"
    }
    return HeteroEmbedding(counts, embed_size)
65

66

67
class RelGraphConvLayer(nn.Module):
    """One relational GCN layer over a heterogeneous graph.

    Applies a per-relation GraphConv (the relation weights live in
    ``self.weight`` and are handed to the conv modules at call time),
    adds a per-node-type self-loop linear transform, then an optional
    activation and dropout.
    """

    def __init__(
        self, in_feat, out_feat, ntypes, rel_names, activation=None, dropout=0.0
    ):
        """
        Parameters
        ----------
        in_feat : int
            Input feature size.
        out_feat : int
            Output feature size.
        ntypes : list[str]
            Node types of the graph (one self-loop weight per type).
        rel_names : list[str]
            Relation (edge type) names (one conv + weight per relation).
        activation : callable or None
            Optional activation applied after the self-loop addition.
        dropout : float
            Dropout probability applied to the layer output.
        """
        super(RelGraphConvLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.ntypes = ntypes
        self.rel_names = rel_names
        self.activation = activation

        # weight=False: the per-relation weights are supplied in forward()
        # via mod_kwargs instead of being owned by the GraphConv modules.
        self.conv = dglnn.HeteroGraphConv(
            {
                rel: dglnn.GraphConv(
                    in_feat, out_feat, norm="right", weight=False, bias=False
                )
                for rel in rel_names
            }
        )

        # One weight matrix per relation, stored as bias-free Linear modules
        # so standard initialization/reset utilities apply.
        self.weight = nn.ModuleDict(
            {
                rel_name: nn.Linear(in_feat, out_feat, bias=False)
                for rel_name in self.rel_names
            }
        )

        # weight for self loop
        self.loop_weights = nn.ModuleDict(
            {
                ntype: nn.Linear(in_feat, out_feat, bias=True)
                for ntype in self.ntypes
            }
        )

        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the per-relation and self-loop weights."""
        for layer in self.weight.values():
            layer.reset_parameters()

        for layer in self.loop_weights.values():
            layer.reset_parameters()

    def forward(self, g, inputs):
        """
        Parameters
        ----------
        g : DGLGraph
            Input graph.
        inputs : dict[str, torch.Tensor]
            Node feature for each node type.

        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each node type.
        """
        g = g.local_var()
        # Hand each relation's weight to its GraphConv; .T reorients the
        # Linear weight (stored (out, in)) to the (in, out) layout that
        # GraphConv's external-weight path expects.
        wdict = {
            rel_name: {"weight": self.weight[rel_name].weight.T}
            for rel_name in self.rel_names
        }

        # The leading num_dst_nodes rows of each feature tensor correspond
        # to the destination nodes of the (bipartite) block; slice them out
        # for the self-loop term.
        inputs_dst = {
            k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()
        }

        hs = self.conv(g, inputs, mod_kwargs=wdict)

        def _apply(ntype, h):
            # Message-passing result + per-type self-loop transform.
            h = h + self.loop_weights[ntype](inputs_dst[ntype])
            if self.activation:
                h = self.activation(h)
            return self.dropout(h)

        return {ntype: _apply(ntype, h) for ntype, h in hs.items()}

146
147

class EntityClassify(nn.Module):
    """Two-layer R-GCN node classifier: input -> 64-dim hidden -> logits."""

    def __init__(self, g, in_dim, out_dim):
        super(EntityClassify, self).__init__()
        self.in_dim = in_dim
        self.h_dim = 64
        self.out_dim = out_dim
        # Deduplicate edge types and fix their order so the parameter
        # layout is deterministic across runs.
        self.rel_names = sorted(set(g.etypes))
        self.dropout = 0.5

        # i2h: input features -> hidden, with ReLU and dropout.
        input_layer = RelGraphConvLayer(
            self.in_dim,
            self.h_dim,
            g.ntypes,
            self.rel_names,
            activation=F.relu,
            dropout=self.dropout,
        )
        # h2o: hidden -> class logits, no activation.
        output_layer = RelGraphConvLayer(
            self.h_dim,
            self.out_dim,
            g.ntypes,
            self.rel_names,
            activation=None,
        )
        self.layers = nn.ModuleList([input_layer, output_layer])

    def reset_parameters(self):
        """Re-initialize every convolution layer."""
        for conv in self.layers:
            conv.reset_parameters()

    def forward(self, h, blocks):
        """Feed features through one sampled block per layer."""
        for conv, block in zip(self.layers, blocks):
            h = conv(block, h)
        return h

190

191
192
193
194
195
196
197
class Logger(object):
    r"""
    This class was taken directly from the PyG implementation and can be found
    here: https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/mag/logger.py

    This was done to ensure that performance was measured in precisely the same way
    """

    def __init__(self, runs):
        # One list of (train, valid, test) accuracy triples per run.
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        """Append one (train, valid, test) triple to the given run."""
        assert len(result) == 3
        assert run >= 0 and run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None):
        """Print per-run stats, or aggregate mean ± std across all runs."""
        if run is not None:
            scores = 100 * th.tensor(self.results[run])
            # Epoch with the best validation accuracy decides "final" stats.
            best_epoch = scores[:, 1].argmax().item()
            print(f"Run {run + 1:02d}:")
            print(f"Highest Train: {scores[:, 0].max():.2f}")
            print(f"Highest Valid: {scores[:, 1].max():.2f}")
            print(f"  Final Train: {scores[best_epoch, 0]:.2f}")
            print(f"   Final Test: {scores[best_epoch, 2]:.2f}")
        else:
            scores = 100 * th.tensor(self.results)

            per_run = []
            for run_scores in scores:
                pick = run_scores[:, 1].argmax()
                per_run.append(
                    (
                        run_scores[:, 0].max().item(),
                        run_scores[:, 1].max().item(),
                        run_scores[pick, 0].item(),
                        run_scores[pick, 2].item(),
                    )
                )

            summary = th.tensor(per_run)

            print(f"All runs:")
            for label, col in (
                ("Highest Train", 0),
                ("Highest Valid", 1),
                ("  Final Train", 2),
                ("   Final Test", 3),
            ):
                vals = summary[:, col]
                print(f"{label}: {vals.mean():.2f} ± {vals.std():.2f}")


def train(
    g,
    model,
    node_embed,
    optimizer,
    train_loader,
    split_idx,
    labels,
    logger,
    device,
    run,
):
    """Train ``model`` for 3 epochs on "paper" node classification.

    After every epoch the model is evaluated on the whole graph via
    ``test`` and the (train, valid, test) accuracies are recorded in
    ``logger`` under ``run``.

    Parameters
    ----------
    g : the full heterogeneous graph (source of raw "paper" features).
    model : the EntityClassify model being trained.
    node_embed : embedding module for featureless node types.
    optimizer : optimizer over model and embedding parameters.
    train_loader : neighbor-sampling DataLoader over training seeds.
    split_idx : OGB train/valid/test index dict.
    labels : flat tensor of "paper" labels.
    logger : Logger accumulating per-epoch results.
    device : device the model and batches run on.
    run : index of the current run, used when recording results.

    Returns
    -------
    Logger
        The same ``logger`` with this run's results appended.
    """
    print("start training...")
    category = "paper"

    for epoch in range(3):
        num_train = split_idx["train"][category].shape[0]
        pbar = tqdm(total=num_train)
        pbar.set_description(f"Epoch {epoch:02d}")
        model.train()

        total_loss = 0

        for input_nodes, seeds, blocks in train_loader:
            blocks = [blk.to(device) for blk in blocks]
            seeds = seeds[
                category
            ]  # we only predict the nodes with type "category"
            batch_size = seeds.shape[0]
            input_nodes_indexes = input_nodes["paper"].to(g.device)
            seeds = seeds.to(labels.device)

            # Learned embeddings for featureless node types.
            emb = extract_embed(node_embed, input_nodes)
            # Add the batch's raw "paper" features
            emb.update(
                {"paper": g.ndata["feat"]["paper"][input_nodes_indexes]}
            )

            emb = {k: e.to(device) for k, e in emb.items()}
            lbl = labels[seeds].to(device)

            optimizer.zero_grad()
            logits = model(emb, blocks)[category]

            y_hat = logits.log_softmax(dim=-1)
            loss = F.nll_loss(y_hat, lbl)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size for a correct epoch mean.
            total_loss += loss.item() * batch_size
            pbar.update(batch_size)

        pbar.close()
        loss = total_loss / num_train

        result = test(g, model, node_embed, labels, device, split_idx)
        logger.add_result(run, result)
        train_acc, valid_acc, test_acc = result
        print(
            f"Run: {run + 1:02d}, "
            f"Epoch: {epoch +1 :02d}, "
            f"Loss: {loss:.4f}, "
            f"Train: {100 * train_acc:.2f}%, "
            f"Valid: {100 * valid_acc:.2f}%, "
            f"Test: {100 * test_acc:.2f}%"
        )

    return logger

309

310
@th.no_grad()
def test(g, model, node_embed, y_true, device, split_idx):
    """Evaluate the model on every "paper" node with full-neighbor sampling.

    Predictions are collected in batches over the whole "paper" node set,
    then scored separately on the train/valid/test splits with the official
    ogbn-mag Evaluator.

    Returns
    -------
    (train_acc, valid_acc, test_acc) : tuple of floats.
    """
    model.eval()
    category = "paper"
    evaluator = Evaluator(name="ogbn-mag")

    # 2 GNN layers
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    loader = dgl.dataloading.DataLoader(
        g,
        {"paper": th.arange(g.num_nodes("paper"))},
        sampler,
        batch_size=16384,
        shuffle=False,
        num_workers=0,
        device=device
    )

    pbar = tqdm(total=y_true.size(0))
    pbar.set_description(f"Inference")

    y_hats = list()

    for input_nodes, seeds, blocks in loader:
        blocks = [blk.to(device) for blk in blocks]
        seeds = seeds[
            category
        ]  # we only predict the nodes with type "category"
        batch_size = seeds.shape[0]
        input_nodes_indexes = input_nodes["paper"].to(g.device)

        # Learned embeddings for featureless node types.
        emb = extract_embed(node_embed, input_nodes)
        # Get the batch's raw "paper" features
        emb.update({"paper": g.ndata["feat"]["paper"][input_nodes_indexes]})
        emb = {k: e.to(device) for k, e in emb.items()}

        logits = model(emb, blocks)[category]
        # Class prediction per seed node, kept 2-D for the Evaluator.
        y_hat = logits.log_softmax(dim=-1).argmax(dim=1, keepdims=True)
        y_hats.append(y_hat.cpu())

        pbar.update(batch_size)

    pbar.close()

    y_pred = th.cat(y_hats, dim=0)
    y_true = th.unsqueeze(y_true, 1)

    train_acc = evaluator.eval(
        {
            "y_true": y_true[split_idx["train"]["paper"]],
            "y_pred": y_pred[split_idx["train"]["paper"]],
        }
    )["acc"]
    valid_acc = evaluator.eval(
        {
            "y_true": y_true[split_idx["valid"]["paper"]],
            "y_pred": y_pred[split_idx["valid"]["paper"]],
        }
    )["acc"]
    test_acc = evaluator.eval(
        {
            "y_true": y_true[split_idx["test"]["paper"]],
            "y_pred": y_pred[split_idx["test"]["paper"]],
        }
    )["acc"]

    return train_acc, valid_acc, test_acc

378
379
380
def is_support_affinity(v_t):
    """Return True if this DGL version supports enable_cpu_affinity.

    dgl supports enable_cpu_affinity since 0.9.1. Version components are
    compared numerically: the original plain string comparison misorders
    multi-digit components (e.g. "0.10.0" < "0.9.1" lexicographically).
    Non-numeric suffixes ("post1", "+cu118") are ignored.
    """
    components = []
    # Drop any local-version suffix ("+cu118"), then take the leading
    # digits of each dotted component.
    for token in v_t.split("+")[0].split("."):
        digits = ""
        for ch in token:
            if not ch.isdigit():
                break
            digits += ch
        components.append(int(digits) if digits else 0)
    return tuple(components) >= (0, 9, 1)
381

382
def main(args):
    """Entry point: load data, build the model, and train ``args.runs`` times."""
    device = "cuda:0" if th.cuda.is_available() else "cpu"

    g, labels, num_classes, split_idx, logger, train_loader = prepare_data(
        args, device
    )

    embed_layer = rel_graph_embed(g, 128).to(device)
    model = EntityClassify(g, 128, num_classes).to(device)

    print(
        f"Number of embedding parameters: {sum(p.numel() for p in embed_layer.parameters())}"
    )
    print(
        f"Number of model parameters: {sum(p.numel() for p in model.parameters())}"
    )

    for run in range(args.runs):
        try:
            embed_layer.reset_parameters()
            model.reset_parameters()
        except AttributeError:
            # old pytorch version doesn't support reset_parameters() API;
            # keep the construction-time initialization in that case.
            pass

        # Fresh optimizer per run over both model and embedding parameters.
        all_params = itertools.chain(
            model.parameters(), embed_layer.parameters()
        )
        optimizer = th.optim.Adam(all_params, lr=0.01)

        # Shared argument tuple so the two train() call sites stay in sync.
        train_args = (
            g,
            model,
            embed_layer,
            optimizer,
            train_loader,
            split_idx,
            labels,
            logger,
            device,
            run,
        )
        if (
            args.num_workers != 0
            and device == "cpu"
            and is_support_affinity(v_t)
        ):
            expected_max = int(psutil.cpu_count(logical=False))
            if args.num_workers >= expected_max:
                # Warn but continue; oversubscribing physical cores degrades
                # dataloader throughput.
                print(f"[ERROR] You specified num_workers are larger than physical cores, please set any number less than {expected_max}", file=sys.stderr)
            # Pin dataloader workers to dedicated CPU cores.
            with train_loader.enable_cpu_affinity():
                logger = train(*train_args)
        else:
            logger = train(*train_args)
        logger.print_statistics(run)

    print("Final performance: ")
    logger.print_statistics()

447
448
449
450

if __name__ == "__main__":
    # Command-line configuration: both options are integers.
    argparser = argparse.ArgumentParser(description="RGCN")
    for flag, default in (("--runs", 10), ("--num_workers", 0)):
        argparser.add_argument(flag, type=int, default=default)

    main(argparser.parse_args())