hetero_rgcn.py 13 KB
Newer Older
1
2
import argparse
import itertools
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
3
4
5
6
7
8
import sys

import dgl
import dgl.nn as dglnn

import psutil
9
10
11
12

import torch as th
import torch.nn as nn
import torch.nn.functional as F
13
14
from dgl import AddReverse, Compose, ToSimple
from dgl.nn import HeteroEmbedding
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
15
16
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from tqdm import tqdm
17

18
19
# Installed DGL version string; passed to is_support_affinity() to gate the
# CPU-affinity DataLoader code path.
v_t = dgl.__version__

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
20

21
def prepare_data(args, device):
    """Load ogbn-mag, normalize the graph, and build the training loader.

    Returns
    -------
    tuple
        (graph, labels, num_classes, split_idx, logger, train_loader) where
        ``labels`` is the flattened "paper" label tensor and ``train_loader``
        is a neighbor-sampling DataLoader over the training split.
    """
    dataset = DglNodePropPredDataset(name="ogbn-mag")
    split_idx = dataset.get_idx_split()
    # graph: dgl graph object, label: torch tensor of shape (num_nodes, num_tasks)
    graph, label_dict = dataset[0]
    labels = label_dict["paper"].flatten()

    # Deduplicate parallel edges and add reverse edges so messages can
    # flow in both directions of every relation.
    graph = Compose([ToSimple(), AddReverse()])(graph)

    print("Loaded graph: {}".format(graph))

    logger = Logger(args.runs)

    # Two-hop neighbor sampling: 25 neighbors at hop 1, 20 at hop 2.
    sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 20])
    train_loader = dgl.dataloading.DataLoader(
        graph,
        split_idx["train"],
        sampler,
        batch_size=1024,
        shuffle=True,
        num_workers=args.num_workers,
        device=device,
    )

    return (
        graph,
        labels,
        dataset.num_classes,
        split_idx,
        logger,
        train_loader,
    )
49

50

YJ-Zhao's avatar
YJ-Zhao committed
51
def extract_embed(node_embed, input_nodes):
    """Look up learnable embeddings for every featureless node type.

    "paper" nodes are excluded because they carry raw input features and
    therefore need no learned embedding.
    """
    featureless = {
        ntype: nids
        for ntype, nids in input_nodes.items()
        if ntype != "paper"
    }
    return node_embed(featureless)
56

57

YJ-Zhao's avatar
YJ-Zhao committed
58
59
60
def rel_graph_embed(graph, embed_size):
    """Build a HeteroEmbedding table for every node type except "paper".

    "paper" nodes are skipped since they come with raw input features.
    """
    sizes = {
        ntype: graph.num_nodes(ntype)
        for ntype in graph.ntypes
        if ntype != "paper"
    }
    return HeteroEmbedding(sizes, embed_size)
66

67

68
class RelGraphConvLayer(nn.Module):
    """Relation-specific graph convolution with per-node-type self-loops.

    Neighbor messages are aggregated separately per relation (via
    HeteroGraphConv) using an externally supplied per-relation projection,
    then a per-node-type linear self-loop term on the destination features
    is added before activation and dropout.
    """

    def __init__(
        self, in_feat, out_feat, ntypes, rel_names, activation=None, dropout=0.0
    ):
        super(RelGraphConvLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.ntypes = ntypes
        self.rel_names = rel_names
        self.activation = activation

        # One GraphConv per relation. ``weight=False`` / ``bias=False``
        # because the projection parameters live in ``self.weight`` below
        # and are handed in at call time through ``mod_kwargs``.
        conv_modules = {}
        for rel in rel_names:
            conv_modules[rel] = dglnn.GraphConv(
                in_feat, out_feat, norm="right", weight=False, bias=False
            )
        self.conv = dglnn.HeteroGraphConv(conv_modules)

        # Per-relation projection applied to aggregated neighbor messages.
        rel_linears = {}
        for rel_name in self.rel_names:
            rel_linears[rel_name] = nn.Linear(in_feat, out_feat, bias=False)
        self.weight = nn.ModuleDict(rel_linears)

        # weight for self loop
        loop_linears = {}
        for ntype in self.ntypes:
            loop_linears[ntype] = nn.Linear(in_feat, out_feat, bias=True)
        self.loop_weights = nn.ModuleDict(loop_linears)

        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the per-relation and self-loop linear layers."""
        for module in self.weight.values():
            module.reset_parameters()
        for module in self.loop_weights.values():
            module.reset_parameters()

    def forward(self, g, inputs):
        """Run one heterogeneous message-passing step.

        Parameters
        ----------
        g : DGLGraph
            Input graph.
        inputs : dict[str, torch.Tensor]
            Node feature for each node type.

        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each node type.
        """
        g = g.local_var()

        # Hand each relation's projection matrix to its GraphConv module.
        wdict = {}
        for rel_name in self.rel_names:
            wdict[rel_name] = {"weight": self.weight[rel_name].weight.T}

        # In a sampled block the destination nodes are the leading rows of
        # each input tensor; slice them out for the self-loop term.
        inputs_dst = {}
        for ntype, feat in inputs.items():
            inputs_dst[ntype] = feat[: g.number_of_dst_nodes(ntype)]

        hs = self.conv(g, inputs, mod_kwargs=wdict)

        outputs = {}
        for ntype, h in hs.items():
            h = h + self.loop_weights[ntype](inputs_dst[ntype])
            if self.activation:
                h = self.activation(h)
            outputs[ntype] = self.dropout(h)
        return outputs

147
148

class EntityClassify(nn.Module):
    """Two-layer R-GCN node classifier: input -> 64-dim hidden -> output."""

    def __init__(self, g, in_dim, out_dim):
        super(EntityClassify, self).__init__()
        self.in_dim = in_dim
        self.h_dim = 64
        self.out_dim = out_dim
        # Deterministic, deduplicated relation order shared by both layers.
        self.rel_names = sorted(set(g.etypes))
        self.dropout = 0.5

        self.layers = nn.ModuleList(
            [
                # i2h
                RelGraphConvLayer(
                    self.in_dim,
                    self.h_dim,
                    g.ntypes,
                    self.rel_names,
                    activation=F.relu,
                    dropout=self.dropout,
                ),
                # h2o
                RelGraphConvLayer(
                    self.h_dim,
                    self.out_dim,
                    g.ntypes,
                    self.rel_names,
                    activation=None,
                ),
            ]
        )

    def reset_parameters(self):
        """Re-initialize every convolution layer."""
        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, h, blocks):
        """Apply each layer to its matching sampled block."""
        for layer, block in zip(self.layers, blocks):
            h = layer(block, h)
        return h

191

192
193
194
195
196
197
198
class Logger(object):
    r"""
    This class was taken directly from the PyG implementation and can be found
    here: https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/mag/logger.py

    This was done to ensure that performance was measured in precisely the same way
    """

    def __init__(self, runs):
        # One list per run; each entry is a (train, valid, test) accuracy triple.
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        """Append one epoch's (train, valid, test) accuracies for ``run``."""
        assert len(result) == 3
        assert 0 <= run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None):
        """Print per-run (``run`` given) or aggregate (``run=None``) stats."""
        if run is not None:
            # Single run: report the best epoch as chosen by validation acc.
            result = 100 * th.tensor(self.results[run])
            best_epoch = result[:, 1].argmax().item()
            print(f"Run {run + 1:02d}:")
            print(f"Highest Train: {result[:, 0].max():.2f}")
            print(f"Highest Valid: {result[:, 1].max():.2f}")
            print(f"  Final Train: {result[best_epoch, 0]:.2f}")
            print(f"   Final Test: {result[best_epoch, 2]:.2f}")
        else:
            # All runs: collect each run's best numbers, report mean ± std.
            result = 100 * th.tensor(self.results)

            best_results = []
            for r in result:
                best_epoch = r[:, 1].argmax()
                best_results.append(
                    (
                        r[:, 0].max().item(),
                        r[:, 1].max().item(),
                        r[best_epoch, 0].item(),
                        r[best_epoch, 2].item(),
                    )
                )

            best_result = th.tensor(best_results)

            print("All runs:")
            captions = (
                "Highest Train",
                "Highest Valid",
                "  Final Train",
                "   Final Test",
            )
            for col, caption in enumerate(captions):
                stats = best_result[:, col]
                print(f"{caption}: {stats.mean():.2f} ± {stats.std():.2f}")


def train(
    g,
    model,
    node_embed,
    optimizer,
    train_loader,
    split_idx,
    labels,
    logger,
    device,
    run,
):
    """Train for 3 epochs, running a full evaluation after each.

    Each epoch's (train, valid, test) accuracies are recorded into
    ``logger`` under index ``run``; the logger is returned.
    """
    print("start training...")
    category = "paper"

    for epoch in range(3):
        num_train = split_idx["train"][category].shape[0]
        pbar = tqdm(total=num_train)
        pbar.set_description(f"Epoch {epoch:02d}")
        model.train()

        total_loss = 0

        for input_nodes, seeds, blocks in train_loader:
            blocks = [blk.to(device) for blk in blocks]
            # we only predict the nodes with type "category"
            seeds = seeds[category]
            batch_size = seeds.shape[0]
            input_nodes_indexes = input_nodes["paper"].to(g.device)
            seeds = seeds.to(labels.device)

            # Learned embeddings for featureless types, plus the batch's
            # raw "paper" features.
            emb = extract_embed(node_embed, input_nodes)
            emb["paper"] = g.ndata["feat"]["paper"][input_nodes_indexes]
            emb = {k: e.to(device) for k, e in emb.items()}
            lbl = labels[seeds].to(device)

            optimizer.zero_grad()
            logits = model(emb, blocks)[category]

            y_hat = logits.log_softmax(dim=-1)
            loss = F.nll_loss(y_hat, lbl)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * batch_size
            pbar.update(batch_size)

        pbar.close()

        # Mean loss over all training seeds this epoch.
        loss = total_loss / num_train

        result = test(g, model, node_embed, labels, device, split_idx)
        logger.add_result(run, result)
        train_acc, valid_acc, test_acc = result
        print(
            f"Run: {run + 1:02d}, "
            f"Epoch: {epoch + 1:02d}, "
            f"Loss: {loss:.4f}, "
            f"Train: {100 * train_acc:.2f}%, "
            f"Valid: {100 * valid_acc:.2f}%, "
            f"Test: {100 * test_acc:.2f}%"
        )

    return logger

308

309
@th.no_grad()
def test(g, model, node_embed, y_true, device, split_idx):
    """Full-neighborhood inference over all "paper" nodes.

    Returns
    -------
    tuple[float, float, float]
        (train, valid, test) accuracy as computed by the OGB evaluator.
    """
    model.eval()
    category = "paper"
    evaluator = Evaluator(name="ogbn-mag")

    # 2 GNN layers
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    loader = dgl.dataloading.DataLoader(
        g,
        {"paper": th.arange(g.num_nodes("paper"))},
        sampler,
        batch_size=16384,
        shuffle=False,
        num_workers=0,
        device=device,
    )

    pbar = tqdm(total=y_true.size(0))
    pbar.set_description("Inference")

    y_hats = list()
    for input_nodes, seeds, blocks in loader:
        blocks = [blk.to(device) for blk in blocks]
        # we only predict the nodes with type "category"
        seeds = seeds[category]
        batch_size = seeds.shape[0]
        input_nodes_indexes = input_nodes["paper"].to(g.device)

        # Learned embeddings for featureless types, plus the batch's
        # raw "paper" features.
        emb = extract_embed(node_embed, input_nodes)
        emb["paper"] = g.ndata["feat"]["paper"][input_nodes_indexes]
        emb = {k: e.to(device) for k, e in emb.items()}

        logits = model(emb, blocks)[category]
        y_hat = logits.log_softmax(dim=-1).argmax(dim=1, keepdims=True)
        y_hats.append(y_hat.cpu())

        pbar.update(batch_size)

    pbar.close()

    y_pred = th.cat(y_hats, dim=0)
    y_true = th.unsqueeze(y_true, 1)

    # Score each split with the official OGB evaluator.
    accs = []
    for split in ("train", "valid", "test"):
        idx = split_idx[split]["paper"]
        accs.append(
            evaluator.eval({"y_true": y_true[idx], "y_pred": y_pred[idx]})[
                "acc"
            ]
        )
    return tuple(accs)

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
377

378
379
380
def is_support_affinity(v_t):
    """Return True if DGL version ``v_t`` supports ``enable_cpu_affinity``.

    The feature was added in DGL 0.9.1.  Versions are compared
    numerically: the previous plain string comparison misorders releases
    such as "0.10.0" < "0.9.1" and mishandles local-build suffixes.
    """

    def _release_tuple(version):
        # Drop any local-build suffix (e.g. "1.0.1+cu113") and keep the
        # leading digits of each dot-separated component; non-numeric
        # components count as 0.
        parts = []
        for token in version.split("+", 1)[0].split("."):
            digits = ""
            for ch in token:
                if not ch.isdigit():
                    break
                digits += ch
            parts.append(int(digits) if digits else 0)
        return tuple(parts)

    # dgl supports enable_cpu_affinity since 0.9.1
    return _release_tuple(v_t) >= (0, 9, 1)
381

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
382

383
def main(args):
    """Train and evaluate the RGCN on ogbn-mag for ``args.runs`` runs."""
    device = "cuda:0" if th.cuda.is_available() else "cpu"

    g, labels, num_classes, split_idx, logger, train_loader = prepare_data(
        args, device
    )

    # 128-dim learnable embeddings for featureless node types; "paper"
    # nodes use their raw 128-dim input features instead.
    embed_layer = rel_graph_embed(g, 128).to(device)
    model = EntityClassify(g, 128, num_classes).to(device)

    print(
        f"Number of embedding parameters: {sum(p.numel() for p in embed_layer.parameters())}"
    )
    print(
        f"Number of model parameters: {sum(p.numel() for p in model.parameters())}"
    )

    for run in range(args.runs):
        try:
            embed_layer.reset_parameters()
            model.reset_parameters()
        except AttributeError:
            # old pytorch version doesn't support reset_parameters() API;
            # was a bare ``except:`` that also swallowed
            # KeyboardInterrupt/SystemExit.
            pass

        # Optimize the model and the embedding table jointly.
        all_params = itertools.chain(
            model.parameters(), embed_layer.parameters()
        )
        optimizer = th.optim.Adam(all_params, lr=0.01)

        def _fit():
            # One full training run; shared by both branches below so the
            # 10-argument call is not duplicated.
            return train(
                g,
                model,
                embed_layer,
                optimizer,
                train_loader,
                split_idx,
                labels,
                logger,
                device,
                run,
            )

        if (
            args.num_workers != 0
            and device == "cpu"
            and is_support_affinity(v_t)
        ):
            expected_max = int(psutil.cpu_count(logical=False))
            if args.num_workers >= expected_max:
                print(
                    f"[ERROR] You specified num_workers are larger than physical cores, please set any number less than {expected_max}",
                    file=sys.stderr,
                )
            # Pin DataLoader workers to CPU cores to speed up sampling.
            with train_loader.enable_cpu_affinity():
                logger = _fit()
        else:
            logger = _fit()
        logger.print_statistics(run)

    print("Final performance: ")
    logger.print_statistics()

456
457
458
459

if __name__ == "__main__":
    # Command-line entry point.
    parser = argparse.ArgumentParser(description="RGCN")
    # Number of independent training runs to aggregate statistics over.
    parser.add_argument("--runs", type=int, default=10)
    # DataLoader worker processes; 0 loads batches in the main process.
    parser.add_argument("--num_workers", type=int, default=0)

    args = parser.parse_args()

    main(args)