# hetero_rgcn.py — R-GCN node classification on ogbn-mag (DGL example).
import argparse
import itertools

import torch as th
import torch.nn as nn
import torch.nn.functional as F
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from tqdm import tqdm

import dgl
import dgl.nn as dglnn
from dgl import AddReverse, Compose, ToSimple
from dgl.nn import HeteroEmbedding

import psutil
import sys
v_t = dgl.__version__

def prepare_data(args, device):
    """Load ogbn-mag, normalize the graph and build the training dataloader.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments; ``runs`` and ``num_workers`` are read here.
    device : str
        Device string the dataloader should place batches on.

    Returns
    -------
    tuple
        (graph, labels, num_classes, split_idx, logger, train_loader).
    """
    dataset = DglNodePropPredDataset(name="ogbn-mag")
    split_idx = dataset.get_idx_split()
    # graph: dgl graph object, label: torch tensor of shape (num_nodes, num_tasks)
    g, labels = dataset[0]
    labels = labels["paper"].flatten()

    # Deduplicate parallel edges, then add reverse edges so messages can
    # flow in both directions across every relation.
    transform = Compose([ToSimple(), AddReverse()])
    g = transform(g)

    print("Loaded graph: {}".format(g))

    logger = Logger(args.runs)

    # Neighbor fan-outs for the two GNN layers during training.
    sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 20])
    train_loader = dgl.dataloading.DataLoader(
        g,
        split_idx["train"],
        sampler,
        batch_size=1024,
        shuffle=True,
        num_workers=args.num_workers,
        device=device,
    )

    return g, labels, dataset.num_classes, split_idx, logger, train_loader

def extract_embed(node_embed, input_nodes):
    """Look up learnable embeddings for every featureless node type.

    "paper" nodes carry raw input features, so they are excluded from the
    embedding lookup here and handled separately by the caller.
    """
    featureless = {
        ntype: nids for ntype, nids in input_nodes.items() if ntype != "paper"
    }
    return node_embed(featureless)

def rel_graph_embed(graph, embed_size):
    """Create a HeteroEmbedding over every node type except "paper".

    "paper" nodes come with raw features and therefore need no learnable
    embedding table.
    """
    sizes = {
        ntype: graph.num_nodes(ntype)
        for ntype in graph.ntypes
        if ntype != "paper"
    }
    return HeteroEmbedding(sizes, embed_size)

class RelGraphConvLayer(nn.Module):
    """A single relational graph convolution layer.

    One GraphConv per relation (via HeteroGraphConv), with per-relation
    linear weights supplied externally and a per-node-type self-loop
    linear transform added on top.
    """

    def __init__(
        self, in_feat, out_feat, ntypes, rel_names, activation=None, dropout=0.0
    ):
        # in_feat/out_feat: feature sizes; ntypes/rel_names: node and
        # relation type names of the heterogeneous graph.
        super(RelGraphConvLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.ntypes = ntypes
        self.rel_names = rel_names
        self.activation = activation

        # weight=False / bias=False: the per-relation weight is injected at
        # forward time through mod_kwargs instead of living in GraphConv.
        self.conv = dglnn.HeteroGraphConv(
            {
                rel: dglnn.GraphConv(
                    in_feat, out_feat, norm="right", weight=False, bias=False
                )
                for rel in rel_names
            }
        )

        # One linear projection per relation; its (transposed) weight is
        # what gets passed to the matching GraphConv in forward().
        self.weight = nn.ModuleDict(
            {
                rel_name: nn.Linear(in_feat, out_feat, bias=False)
                for rel_name in self.rel_names
            }
        )

        # weight for self loop
        self.loop_weights = nn.ModuleDict(
            {
                ntype: nn.Linear(in_feat, out_feat, bias=True)
                for ntype in self.ntypes
            }
        )

        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        # Reinitialize per-relation weights and per-type self-loop weights.
        for layer in self.weight.values():
            layer.reset_parameters()

        for layer in self.loop_weights.values():
            layer.reset_parameters()

    def forward(self, g, inputs):
        """
        Parameters
        ----------
        g : DGLHeteroGraph
            Input graph.
        inputs : dict[str, torch.Tensor]
            Node feature for each node type.

        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each node type.
        """
        g = g.local_var()
        # nn.Linear stores weight as (out_feat, in_feat); GraphConv's
        # external weight is expected the other way round, hence the .T.
        wdict = {
            rel_name: {"weight": self.weight[rel_name].weight.T}
            for rel_name in self.rel_names
        }

        # In a sampled block the first num_dst nodes of each type are the
        # destination nodes; slice them out for the self-loop term.
        inputs_dst = {
            k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()
        }

        hs = self.conv(g, inputs, mod_kwargs=wdict)

        def _apply(ntype, h):
            # message-passing output + self-loop transform of dst features
            h = h + self.loop_weights[ntype](inputs_dst[ntype])
            if self.activation:
                h = self.activation(h)
            return self.dropout(h)

        return {ntype: _apply(ntype, h) for ntype, h in hs.items()}

class EntityClassify(nn.Module):
    """Two-layer R-GCN node classifier: in_dim -> 64 hidden -> out_dim."""

    def __init__(self, g, in_dim, out_dim):
        super(EntityClassify, self).__init__()
        self.in_dim = in_dim
        self.h_dim = 64
        self.out_dim = out_dim
        # Deterministic relation ordering shared by both layers.
        self.rel_names = sorted(set(g.etypes))
        self.dropout = 0.5

        # (fan_in, fan_out, activation, dropout) for i2h and h2o layers;
        # the output layer has no activation and no dropout.
        layer_specs = [
            (self.in_dim, self.h_dim, F.relu, self.dropout),
            (self.h_dim, self.out_dim, None, 0.0),
        ]
        self.layers = nn.ModuleList(
            RelGraphConvLayer(
                fan_in,
                fan_out,
                g.ntypes,
                self.rel_names,
                activation=act,
                dropout=drop,
            )
            for fan_in, fan_out, act, drop in layer_specs
        )

    def reset_parameters(self):
        """Reinitialize every convolution layer."""
        for conv in self.layers:
            conv.reset_parameters()

    def forward(self, h, blocks):
        """Run one sampled block through each layer in turn."""
        for conv, block in zip(self.layers, blocks):
            h = conv(block, h)
        return h

class Logger(object):
    r"""
    This class was taken directly from the PyG implementation and can be found
    here: https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/mag/logger.py

    This was done to ensure that performance was measured in precisely the same way
    """

    def __init__(self, runs):
        # One list per run; each entry is a (train, valid, test) accuracy triple.
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        """Record one epoch's (train, valid, test) accuracies for a run."""
        assert len(result) == 3
        assert 0 <= run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None):
        """Print per-run stats, or aggregated mean/std over all runs."""
        if run is not None:
            scores = 100 * th.tensor(self.results[run])
            # Epoch with the best validation accuracy selects the final model.
            best_epoch = scores[:, 1].argmax().item()
            print(f"Run {run + 1:02d}:")
            print(f"Highest Train: {scores[:, 0].max():.2f}")
            print(f"Highest Valid: {scores[:, 1].max():.2f}")
            print(f"  Final Train: {scores[best_epoch, 0]:.2f}")
            print(f"   Final Test: {scores[best_epoch, 2]:.2f}")
        else:
            scores = 100 * th.tensor(self.results)

            summary = []
            for per_run in scores:
                pick = per_run[:, 1].argmax()
                summary.append(
                    (
                        per_run[:, 0].max().item(),
                        per_run[:, 1].max().item(),
                        per_run[pick, 0].item(),
                        per_run[pick, 2].item(),
                    )
                )
            summary = th.tensor(summary)

            print("All runs:")
            for label, col in (
                ("Highest Train", 0),
                ("Highest Valid", 1),
                ("  Final Train", 2),
                ("   Final Test", 3),
            ):
                vals = summary[:, col]
                print(f"{label}: {vals.mean():.2f} ± {vals.std():.2f}")

def train(
    g,
    model,
    node_embed,
    optimizer,
    train_loader,
    split_idx,
    labels,
    logger,
    device,
    run,
):
    """Train the model for 3 epochs on one run, evaluating after each epoch.

    Parameters
    ----------
    g : DGLHeteroGraph
        Full graph; raw "paper" features are read from ``g.ndata``.
    model : EntityClassify
        The R-GCN classifier being trained.
    node_embed : HeteroEmbedding
        Learnable embeddings for featureless node types.
    optimizer : torch.optim.Optimizer
        Optimizer over model + embedding parameters.
    train_loader : dgl.dataloading.DataLoader
        Mini-batch loader over the training split.
    split_idx, labels, logger, device, run
        Split indices, "paper" labels, result logger, target device, run id.

    Returns
    -------
    Logger
        The logger with this run's results appended.
    """
    print("start training...")
    category = "paper"

    for epoch in range(3):
        num_train = split_idx["train"][category].shape[0]
        pbar = tqdm(total=num_train)
        pbar.set_description(f"Epoch {epoch:02d}")
        model.train()

        total_loss = 0

        for input_nodes, seeds, blocks in train_loader:
            blocks = [blk.to(device) for blk in blocks]
            seeds = seeds[
                category
            ]  # we only predict the nodes with type "category"
            batch_size = seeds.shape[0]

            emb = extract_embed(node_embed, input_nodes)
            # Add the batch's raw "paper" features
            emb.update(
                {"paper": g.ndata["feat"]["paper"][input_nodes["paper"]]}
            )

            emb = {k: e.to(device) for k, e in emb.items()}
            lbl = labels[seeds].to(device)

            optimizer.zero_grad()
            logits = model(emb, blocks)[category]

            # NLL over log-softmax is equivalent to cross-entropy on logits.
            y_hat = logits.log_softmax(dim=-1)
            loss = F.nll_loss(y_hat, lbl)
            loss.backward()
            optimizer.step()

            # Weight batch loss by batch size for a correct epoch average.
            total_loss += loss.item() * batch_size
            pbar.update(batch_size)

        pbar.close()
        loss = total_loss / num_train

        # Full-graph evaluation after every epoch.
        result = test(g, model, node_embed, labels, device, split_idx)
        logger.add_result(run, result)
        train_acc, valid_acc, test_acc = result
        print(
            f"Run: {run + 1:02d}, "
            f"Epoch: {epoch +1 :02d}, "
            f"Loss: {loss:.4f}, "
            f"Train: {100 * train_acc:.2f}%, "
            f"Valid: {100 * valid_acc:.2f}%, "
            f"Test: {100 * test_acc:.2f}%"
        )

    return logger

@th.no_grad()
def test(g, model, node_embed, y_true, device, split_idx):
    """Evaluate the model on all "paper" nodes with full-neighbor sampling.

    Returns
    -------
    tuple[float, float, float]
        (train_acc, valid_acc, test_acc) as computed by the OGB evaluator.
    """
    model.eval()
    category = "paper"
    evaluator = Evaluator(name="ogbn-mag")

    # 2 GNN layers
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    loader = dgl.dataloading.DataLoader(
        g,
        {"paper": th.arange(g.num_nodes("paper"))},
        sampler,
        batch_size=16384,
        shuffle=False,
        num_workers=0,
        device=device
    )

    pbar = tqdm(total=y_true.size(0))
    pbar.set_description(f"Inference")

    y_hats = list()

    for input_nodes, seeds, blocks in loader:
        blocks = [blk.to(device) for blk in blocks]
        seeds = seeds[
            category
        ]  # we only predict the nodes with type "category"
        batch_size = seeds.shape[0]

        emb = extract_embed(node_embed, input_nodes)
        # Get the batch's raw "paper" features
        emb.update({"paper": g.ndata["feat"]["paper"][input_nodes["paper"]]})
        emb = {k: e.to(device) for k, e in emb.items()}

        # NOTE(review): `keepdims` is the numpy-style spelling; torch's
        # canonical kwarg is `keepdim` — confirm the installed torch accepts it.
        logits = model(emb, blocks)[category]
        y_hat = logits.log_softmax(dim=-1).argmax(dim=1, keepdims=True)
        y_hats.append(y_hat.cpu())

        pbar.update(batch_size)

    pbar.close()

    y_pred = th.cat(y_hats, dim=0)
    # OGB evaluator expects shape (num_nodes, 1) for both y_true and y_pred.
    y_true = th.unsqueeze(y_true, 1)

    train_acc = evaluator.eval(
        {
            "y_true": y_true[split_idx["train"]["paper"]],
            "y_pred": y_pred[split_idx["train"]["paper"]],
        }
    )["acc"]
    valid_acc = evaluator.eval(
        {
            "y_true": y_true[split_idx["valid"]["paper"]],
            "y_pred": y_pred[split_idx["valid"]["paper"]],
        }
    )["acc"]
    test_acc = evaluator.eval(
        {
            "y_true": y_true[split_idx["test"]["paper"]],
            "y_pred": y_pred[split_idx["test"]["paper"]],
        }
    )["acc"]

    return train_acc, valid_acc, test_acc

def is_support_affinity(v_t):
    """Return True if this DGL version supports enable_cpu_affinity.

    dgl supports enable_cpu_affinity since 0.9.1. The original lexicographic
    string comparison misorders versions (e.g. "0.10.0" < "0.9.1" as strings),
    so compare numeric component tuples instead.
    """

    def _version_tuple(version):
        # Strip any local-version suffix (e.g. "+cu113"), then keep the
        # leading digits of each dot-separated component.
        parts = []
        for token in version.split("+")[0].split("."):
            digits = ""
            for ch in token:
                if not ch.isdigit():
                    break
                digits += ch
            parts.append(int(digits) if digits else 0)
        return tuple(parts)

    return _version_tuple(v_t) >= _version_tuple("0.9.1")

def main(args):
    """Run the full pipeline: data prep, then ``args.runs`` training runs."""
    device = "cuda:0" if th.cuda.is_available() else "cpu"

    g, labels, num_classes, split_idx, logger, train_loader = prepare_data(
        args, device
    )

    embed_layer = rel_graph_embed(g, 128).to(device)
    model = EntityClassify(g, 128, num_classes).to(device)

    print(
        f"Number of embedding parameters: {sum(p.numel() for p in embed_layer.parameters())}"
    )
    print(
        f"Number of model parameters: {sum(p.numel() for p in model.parameters())}"
    )

    for run in range(args.runs):
        try:
            embed_layer.reset_parameters()
            model.reset_parameters()
        except AttributeError:
            # old pytorch version doesn't support reset_parameters() API
            pass

        # optimizer over the model and the learnable node embeddings jointly
        all_params = itertools.chain(
            model.parameters(), embed_layer.parameters()
        )
        optimizer = th.optim.Adam(all_params, lr=0.01)

        # Pin dataloader workers to physical cores when running on CPU with
        # worker processes and a DGL version that supports it.
        if args.num_workers != 0 and device == "cpu" and is_support_affinity(v_t):
            expected_max = int(psutil.cpu_count(logical=False))
            if args.num_workers >= expected_max:
                print(
                    f"[ERROR] You specified num_workers are larger than physical cores, please set any number less than {expected_max}",
                    file=sys.stderr,
                )
            with train_loader.enable_cpu_affinity():
                logger = train(
                    g,
                    model,
                    embed_layer,
                    optimizer,
                    train_loader,
                    split_idx,
                    labels,
                    logger,
                    device,
                    run,
                )
        else:
            logger = train(
                g,
                model,
                embed_layer,
                optimizer,
                train_loader,
                split_idx,
                labels,
                logger,
                device,
                run,
            )
        logger.print_statistics(run)

    print("Final performance: ")
    logger.print_statistics()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RGCN")
    # Number of independent training runs to aggregate statistics over.
    parser.add_argument("--runs", type=int, default=10)
    # DataLoader worker processes; 0 loads batches in the main process.
    parser.add_argument("--num_workers", type=int, default=0)

    args = parser.parse_args()

    main(args)