"""
Modeling Relational Data with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1703.06103
Code: https://github.com/tkipf/relational-gcn
Differences compared to tkipf/relational-gcn:
* l2norm applied to all weights
* remove nodes that won't be touched
"""
import argparse
import gc
import os
import itertools
import time

import numpy as np

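# Select the DGL backend before `dgl` itself is imported below.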
os.environ["DGLBACKEND"] = "pytorch"

from functools import partial

import dgl
import torch as th
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F

import tqdm
from dgl import DGLGraph, nn as dglnn
from dgl.distributed import DistDataLoader

from ogb.nodeproppred import DglNodePropPredDataset
from torch.multiprocessing import Queue
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader


class RelGraphConvLayer(nn.Module):
    r"""Relational graph convolution layer.
    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    rel_names : list[str]
        Relation names.
    num_bases : int, optional
        Number of bases. If None, use the number of relations. Default: None.
    weight : bool, optional
        True if a linear layer is applied after message passing. Default: True
    bias : bool, optional
        True if bias is added. Default: True
    activation : callable, optional
        Activation function. Default: None
    self_loop : bool, optional
        True to include self loop message. Default: False
    dropout : float, optional
        Dropout rate. Default: 0.0
    """

    def __init__(
        self,
        in_feat,
        out_feat,
        rel_names,
        num_bases,
        *,
        weight=True,
        bias=True,
        activation=None,
        self_loop=False,
        dropout=0.0
    ):
        super(RelGraphConvLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.rel_names = rel_names
        # Treat None as one basis per relation (i.e., no basis sharing),
        # as promised in the docstring; otherwise None breaks the
        # `num_bases < len(rel_names)` comparison below.
        if num_bases is None:
            num_bases = len(rel_names)
        self.num_bases = num_bases
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop

        self.conv = dglnn.HeteroGraphConv(
            {
                rel: dglnn.GraphConv(
                    in_feat, out_feat, norm="right", weight=False, bias=False
                )
                for rel in rel_names
            }
        )

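        # Basis decomposition (Schlichtkrull et al.): each relation's weight
        # matrix is a learned combination of num_bases shared bases, so the
        # parameter count grows with the number of bases, not relations.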
        self.use_weight = weight
        self.use_basis = num_bases < len(self.rel_names) and weight
        if self.use_weight:
            if self.use_basis:
                self.basis = dglnn.WeightBasis(
                    (in_feat, out_feat), num_bases, len(self.rel_names)
                )
            else:
                self.weight = nn.Parameter(
                    th.Tensor(len(self.rel_names), in_feat, out_feat)
                )
                nn.init.xavier_uniform_(
                    self.weight, gain=nn.init.calculate_gain("relu")
                )

        # bias
        if bias:
            self.h_bias = nn.Parameter(th.Tensor(out_feat))
            nn.init.zeros_(self.h_bias)

        # weight for self loop
        if self.self_loop:
            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
            nn.init.xavier_uniform_(
                self.loop_weight, gain=nn.init.calculate_gain("relu")
            )

        self.dropout = nn.Dropout(dropout)

    def forward(self, g, inputs):
        """Forward computation
        Parameters
        ----------
        g : DGLGraph
            Input graph.
        inputs : dict[str, torch.Tensor]
            Node feature for each node type.
        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each node type.
        """
        g = g.local_var()
        if self.use_weight:
            weight = self.basis() if self.use_basis else self.weight
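            # Split the (num_rels, in_feat, out_feat) weight tensor into one
            # (in_feat, out_feat) matrix per relation and hand each to its
            # GraphConv module through HeteroGraphConv's mod_kwargs.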
            wdict = {
                self.rel_names[i]: {"weight": w.squeeze(0)}
                for i, w in enumerate(th.split(weight, 1, dim=0))
            }
        else:
            wdict = {}

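        # In a sampled block, the first num_dst_nodes(k) rows of each feature
        # tensor are the destination (seed) nodes; slice them out so the
        # self-loop term is applied to destination nodes only.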
        if g.is_block:
            inputs_src = inputs
            inputs_dst = {
                k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()
            }
        else:
            inputs_src = inputs_dst = inputs

        hs = self.conv(g, inputs_src, mod_kwargs=wdict)

        def _apply(ntype, h):
            if self.self_loop:
                h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
            if self.bias:
                h = h + self.h_bias
            if self.activation:
                h = self.activation(h)
            return self.dropout(h)

        return {ntype: _apply(ntype, h) for ntype, h in hs.items()}


class EntityClassify(nn.Module):
    """Entity classification class for RGCN
    Parameters
    ----------
    device : th.device
        Device to run the model on.
    h_dim : int
        Hidden dimension size.
    out_dim : int
        Output dimension size.
    rel_names : list[str]
        A list of relation names.
    num_bases : int
        Number of bases. If None or negative, use the number of relations.
    num_hidden_layers : int
        Number of hidden RelGraphConvLayers.
    dropout : float
        Dropout rate.
    use_self_loop : bool
        Use self loop if True. Default: False.
    layer_norm : bool
        Stored for API compatibility; not used by the current layers.
    """

    def __init__(
        self,
        device,
        h_dim,
        out_dim,
        rel_names,
        num_bases=None,
        num_hidden_layers=1,
        dropout=0,
        use_self_loop=False,
        layer_norm=False,
    ):
        super(EntityClassify, self).__init__()
        self.device = device
        self.h_dim = h_dim
        self.out_dim = out_dim
        # None or a negative value means one weight per relation (no sharing).
        self.num_bases = (
            None if num_bases is None or num_bases < 0 else num_bases
        )
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.use_self_loop = use_self_loop
        self.layer_norm = layer_norm

        self.layers = nn.ModuleList()
        # i2h
        self.layers.append(
            RelGraphConvLayer(
                self.h_dim,
                self.h_dim,
                rel_names,
                self.num_bases,
                activation=F.relu,
                self_loop=self.use_self_loop,
                dropout=self.dropout,
            )
        )
        # h2h
        for idx in range(self.num_hidden_layers):
            self.layers.append(
                RelGraphConvLayer(
                    self.h_dim,
                    self.h_dim,
                    rel_names,
                    self.num_bases,
                    activation=F.relu,
                    self_loop=self.use_self_loop,
                    dropout=self.dropout,
                )
            )
        # h2o
        self.layers.append(
            RelGraphConvLayer(
                self.h_dim,
                self.out_dim,
                rel_names,
                self.num_bases,
                activation=None,
                self_loop=self.use_self_loop,
            )
        )

    def forward(self, blocks, feats, norm=None):
        if blocks is None:
            # Full-graph training; assumes self.g was attached to the model.
            blocks = [self.g] * len(self.layers)
        h = feats
        for layer, block in zip(self.layers, blocks):
            block = block.to(self.device)
            h = layer(block, h)
        return h


def init_emb(shape, dtype):
    arr = th.zeros(shape, dtype=dtype)
    nn.init.uniform_(arr, -1.0, 1.0)
    return arr


class DistEmbedLayer(nn.Module):
    r"""Embedding layer for featureless heterograph.
    Parameters
    ----------
    dev_id : int
        Device to run the layer.
    g : DistGraph
        training graph
    embed_size : int
        Output embed size
    sparse_emb: bool
        Whether to use sparse embedding
        Default: False
    dgl_sparse_emb: bool
        Whether to use DGL sparse embedding
        Default: False
    embed_name : str, optional
        Embed name
    """

    def __init__(
        self,
        dev_id,
        g,
        embed_size,
        sparse_emb=False,
        dgl_sparse_emb=False,
        feat_name="feat",
        embed_name="node_emb",
    ):
        super(DistEmbedLayer, self).__init__()
        self.dev_id = dev_id
        self.embed_size = embed_size
        self.embed_name = embed_name
        self.feat_name = feat_name
        self.sparse_emb = sparse_emb
        self.g = g
        self.ntype_id_map = {g.get_ntype_id(ntype): ntype for ntype in g.ntypes}

        self.node_projs = nn.ModuleDict()
        for ntype in g.ntypes:
            if feat_name in g.nodes[ntype].data:
                self.node_projs[ntype] = nn.Linear(
                    g.nodes[ntype].data[feat_name].shape[1], embed_size
                )
                nn.init.xavier_uniform_(self.node_projs[ntype].weight)
                print("node {} has data {}".format(ntype, feat_name))
        if sparse_emb:
            if dgl_sparse_emb:
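                # DGL sparse embeddings live in the distributed KVStore,
                # partitioned by part_policy, and are updated with
                # dgl.distributed.optim.SparseAdam in run().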
                self.node_embeds = {}
                for ntype in g.ntypes:
                    # We only create embeddings for nodes without node features.
                    if feat_name not in g.nodes[ntype].data:
                        part_policy = g.get_node_partition_policy(ntype)
                        self.node_embeds[ntype] = dgl.distributed.DistEmbedding(
                            g.num_nodes(ntype),
                            self.embed_size,
                            embed_name + "_" + ntype,
                            init_emb,
                            part_policy,
                        )
            else:
                self.node_embeds = nn.ModuleDict()
                for ntype in g.ntypes:
                    # We only create embeddings for nodes without node features.
                    if feat_name not in g.nodes[ntype].data:
                        self.node_embeds[ntype] = th.nn.Embedding(
                            g.num_nodes(ntype),
                            self.embed_size,
                            sparse=self.sparse_emb,
                        )
                        nn.init.uniform_(
                            self.node_embeds[ntype].weight, -1.0, 1.0
                        )
        else:
            self.node_embeds = nn.ModuleDict()
            for ntype in g.ntypes:
                # We only create embeddings for nodes without node features.
                if feat_name not in g.nodes[ntype].data:
                    self.node_embeds[ntype] = th.nn.Embedding(
                        g.num_nodes(ntype), self.embed_size
                    )
                    nn.init.uniform_(self.node_embeds[ntype].weight, -1.0, 1.0)

    def forward(self, node_ids):
        """Forward computation
        Parameters
        ----------
354
        node_ids : dict of Tensor
355
356
357
358
359
360
            node ids to generate embedding for.
        Returns
        -------
        tensor
            embeddings as the input of the next layer
        """
        embeds = {}
        for ntype in node_ids:
            if self.feat_name in self.g.nodes[ntype].data:
                embeds[ntype] = self.node_projs[ntype](
                    self.g.nodes[ntype]
                    .data[self.feat_name][node_ids[ntype]]
                    .to(self.dev_id)
                )
            else:
                embeds[ntype] = self.node_embeds[ntype](node_ids[ntype]).to(
                    self.dev_id
                )
        return embeds


def compute_acc(results, labels):
    """
    Compute the accuracy of prediction given the labels.
    """
    labels = labels.long()
    return (results == labels).float().sum() / len(results)


def evaluate(
    g,
    model,
    embed_layer,
    labels,
    eval_loader,
    test_loader,
    all_val_nid,
    all_test_nid,
):
    model.eval()
    embed_layer.eval()
    eval_logits = []
    eval_seeds = []

    global_results = dgl.distributed.DistTensor(
        labels.shape, th.long, "results", persistent=True
    )
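    # Each trainer writes predictions for its own evaluation seeds into this
    # shared tensor; after the barrier, rank 0 scores the full validation and
    # test node sets against it.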

    with th.no_grad():
        th.cuda.empty_cache()
        for sample_data in tqdm.tqdm(eval_loader):
            input_nodes, seeds, blocks = sample_data
            seeds = seeds["paper"]
            feats = embed_layer(input_nodes)
            logits = model(blocks, feats)
            assert len(logits) == 1
            logits = logits["paper"]
            eval_logits.append(logits.cpu().detach())
            assert np.all(seeds.numpy() < g.num_nodes("paper"))
            eval_seeds.append(seeds.cpu().detach())
    eval_logits = th.cat(eval_logits)
    eval_seeds = th.cat(eval_seeds)
    global_results[eval_seeds] = eval_logits.argmax(dim=1)

    test_logits = []
    test_seeds = []
    with th.no_grad():
        th.cuda.empty_cache()
        for sample_data in tqdm.tqdm(test_loader):
            input_nodes, seeds, blocks = sample_data
            seeds = seeds["paper"]
            feats = embed_layer(input_nodes)
            logits = model(blocks, feats)
            assert len(logits) == 1
            logits = logits["paper"]
            test_logits.append(logits.cpu().detach())
            assert np.all(seeds.numpy() < g.num_nodes("paper"))
            test_seeds.append(seeds.cpu().detach())
    test_logits = th.cat(test_logits)
    test_seeds = th.cat(test_seeds)
    global_results[test_seeds] = test_logits.argmax(dim=1)

    g.barrier()
    if g.rank() == 0:
        return compute_acc(
            global_results[all_val_nid], labels[all_val_nid]
        ), compute_acc(global_results[all_test_nid], labels[all_test_nid])
    else:
        return -1, -1


def run(args, device, data):
    (
        g,
        num_classes,
        train_nid,
        val_nid,
        test_nid,
        labels,
        all_val_nid,
        all_test_nid,
    ) = data

    fanouts = [int(fanout) for fanout in args.fanout.split(",")]
    val_fanouts = [int(fanout) for fanout in args.validation_fanout.split(",")]

    sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
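    # DistNodeDataLoader issues sampling requests against the DistGraph and
    # yields (input_nodes, seeds, blocks) mini-batches for this trainer's
    # share of the training seed nodes.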
    dataloader = dgl.dataloading.DistNodeDataLoader(
        g,
        {"paper": train_nid},
        sampler,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
    )

    valid_sampler = dgl.dataloading.MultiLayerNeighborSampler(val_fanouts)
    valid_dataloader = dgl.dataloading.DistNodeDataLoader(
        g,
        {"paper": val_nid},
        valid_sampler,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
    )

    test_sampler = dgl.dataloading.MultiLayerNeighborSampler(val_fanouts)
    test_dataloader = dgl.dataloading.DistNodeDataLoader(
        g,
        {"paper": test_nid},
        test_sampler,
        batch_size=args.eval_batch_size,
        shuffle=False,
        drop_last=False,
    )

    embed_layer = DistEmbedLayer(
        device,
        g,
        args.n_hidden,
        sparse_emb=args.sparse_embedding,
        dgl_sparse_emb=args.dgl_sparse,
        feat_name="feat",
    )

    model = EntityClassify(
        device,
        args.n_hidden,
        num_classes,
        g.etypes,
        num_bases=args.n_bases,
        num_hidden_layers=args.n_layers - 2,
        dropout=args.dropout,
        use_self_loop=args.use_self_loop,
        layer_norm=args.layer_norm,
    )
    model = model.to(device)

    if not args.standalone:
        if args.num_gpus == -1:
            model = DistributedDataParallel(model)
            # If there are dense parameters in the embedding layer
            # or we use PyTorch sparse embeddings.
            if len(embed_layer.node_projs) > 0 or not args.dgl_sparse:
                embed_layer = DistributedDataParallel(embed_layer)
        else:
            dev_id = g.rank() % args.num_gpus
            model = DistributedDataParallel(
                model, device_ids=[dev_id], output_device=dev_id
            )
            # If there are dense parameters in the embedding layer
            # or we use PyTorch sparse embeddings.
            if len(embed_layer.node_projs) > 0 or not args.dgl_sparse:
                embed_layer = embed_layer.to(device)
                embed_layer = DistributedDataParallel(
                    embed_layer, device_ids=[dev_id], output_device=dev_id
                )

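    # Sparse embedding tables need a dedicated sparse optimizer:
    # dgl.distributed.optim.SparseAdam for DGL DistEmbedding, or
    # th.optim.SparseAdam for th.nn.Embedding(sparse=True). Dense model and
    # projection parameters are handled by the regular Adam below.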
    if args.sparse_embedding:
        if args.dgl_sparse and args.standalone:
            emb_optimizer = dgl.distributed.optim.SparseAdam(
                list(embed_layer.node_embeds.values()), lr=args.sparse_lr
            )
            print(
                "optimize DGL sparse embedding:", embed_layer.node_embeds.keys()
            )
        elif args.dgl_sparse:
            emb_optimizer = dgl.distributed.optim.SparseAdam(
                list(embed_layer.module.node_embeds.values()), lr=args.sparse_lr
            )
            print(
                "optimize DGL sparse embedding:",
                embed_layer.module.node_embeds.keys(),
            )
        elif args.standalone:
            emb_optimizer = th.optim.SparseAdam(
                list(embed_layer.node_embeds.parameters()), lr=args.sparse_lr
            )
            print("optimize Pytorch sparse embedding:", embed_layer.node_embeds)
        else:
            emb_optimizer = th.optim.SparseAdam(
                list(embed_layer.module.node_embeds.parameters()),
                lr=args.sparse_lr,
            )
            print(
                "optimize Pytorch sparse embedding:",
                embed_layer.module.node_embeds,
            )

        dense_params = list(model.parameters())
        if args.standalone:
            dense_params += list(embed_layer.node_projs.parameters())
            print("optimize dense projection:", embed_layer.node_projs)
        else:
            dense_params += list(embed_layer.module.node_projs.parameters())
            print("optimize dense projection:", embed_layer.module.node_projs)
        optimizer = th.optim.Adam(
            dense_params, lr=args.lr, weight_decay=args.l2norm
        )
    else:
        all_params = list(model.parameters()) + list(embed_layer.parameters())
        optimizer = th.optim.Adam(
            all_params, lr=args.lr, weight_decay=args.l2norm
        )

    # training loop
    print("start training...")
    for epoch in range(args.n_epochs):
        tic = time.time()

        sample_time = 0
        copy_time = 0
        forward_time = 0
        backward_time = 0
        update_time = 0
        number_train = 0
        number_input = 0

        step_time = []
        iter_t = []
        sample_t = []
        feat_copy_t = []
        forward_t = []
        backward_t = []
        update_t = []
        iter_tput = []

        start = time.time()
        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
        for step, sample_data in enumerate(dataloader):
            input_nodes, seeds, blocks = sample_data
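            # Mini-batch outputs are keyed by node type; only "paper" nodes
            # carry labels in this dataset (e.g., ogbn-mag).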
            seeds = seeds["paper"]
            number_train += seeds.shape[0]
            number_input += np.sum(
                [blocks[0].num_src_nodes(ntype) for ntype in blocks[0].ntypes]
            )
            tic_step = time.time()
            sample_time += tic_step - start
            sample_t.append(tic_step - start)

            feats = embed_layer(input_nodes)
            label = labels[seeds].to(device)
            copy_time = time.time()
            feat_copy_t.append(copy_time - tic_step)

            # forward
            logits = model(blocks, feats)
            assert len(logits) == 1
            logits = logits["paper"]
            loss = F.cross_entropy(logits, label)
            forward_end = time.time()

            # backward
            optimizer.zero_grad()
            if args.sparse_embedding:
                emb_optimizer.zero_grad()
            loss.backward()
            compute_end = time.time()
            forward_t.append(forward_end - copy_time)
            backward_t.append(compute_end - forward_end)

            # Update model parameters
            optimizer.step()
            if args.sparse_embedding:
                emb_optimizer.step()
            update_t.append(time.time() - compute_end)
            step_t = time.time() - start
            step_time.append(step_t)

            train_acc = th.sum(logits.argmax(dim=1) == label).item() / len(
                seeds
            )

            if step % args.log_every == 0:
                print(
                    "[{}] Epoch {:05d} | Step {:05d} | Train acc {:.4f} | Loss {:.4f} | time {:.3f} s"
                    "| sample {:.3f} | copy {:.3f} | forward {:.3f} | backward {:.3f} | update {:.3f}".format(
                        g.rank(),
                        epoch,
                        step,
                        train_acc,
                        loss.item(),
                        np.sum(step_time[-args.log_every :]),
                        np.sum(sample_t[-args.log_every :]),
                        np.sum(feat_copy_t[-args.log_every :]),
                        np.sum(forward_t[-args.log_every :]),
                        np.sum(backward_t[-args.log_every :]),
                        np.sum(update_t[-args.log_every :]),
                    )
                )
            start = time.time()

        gc.collect()
        print(
            "[{}]Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #train: {}, #input: {}".format(
                g.rank(),
                np.sum(step_time),
                np.sum(sample_t),
                np.sum(feat_copy_t),
                np.sum(forward_t),
                np.sum(backward_t),
                np.sum(update_t),
                number_train,
                number_input,
            )
        )

        start = time.time()
        g.barrier()
        val_acc, test_acc = evaluate(
            g,
            model,
            embed_layer,
            labels,
            valid_dataloader,
            test_dataloader,
            all_val_nid,
            all_test_nid,
        )
        if val_acc >= 0:
            print(
                "Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format(
                    val_acc, test_acc, time.time() - start
                )
            )


def main(args):
    dgl.distributed.initialize(args.ip_config)
    if not args.standalone:
        th.distributed.init_process_group(backend="gloo")

    g = dgl.distributed.DistGraph(args.graph_name, part_config=args.conf_path)
    print("rank:", g.rank())

    pb = g.get_partition_book()
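    # node_split gives each trainer a disjoint, roughly even slice of the
    # masked nodes; when "trainer_id" is present, a node is assigned to a
    # trainer co-located with the partition that owns it.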
    if "trainer_id" in g.nodes["paper"].data:
        train_nid = dgl.distributed.node_split(
            g.nodes["paper"].data["train_mask"],
            pb,
            ntype="paper",
            force_even=True,
            node_trainer_ids=g.nodes["paper"].data["trainer_id"],
        )
        val_nid = dgl.distributed.node_split(
            g.nodes["paper"].data["val_mask"],
            pb,
            ntype="paper",
            force_even=True,
            node_trainer_ids=g.nodes["paper"].data["trainer_id"],
        )
        test_nid = dgl.distributed.node_split(
            g.nodes["paper"].data["test_mask"],
            pb,
            ntype="paper",
            force_even=True,
            node_trainer_ids=g.nodes["paper"].data["trainer_id"],
        )
    else:
        train_nid = dgl.distributed.node_split(
            g.nodes["paper"].data["train_mask"],
            pb,
            ntype="paper",
            force_even=True,
        )
        val_nid = dgl.distributed.node_split(
            g.nodes["paper"].data["val_mask"],
            pb,
            ntype="paper",
            force_even=True,
        )
        test_nid = dgl.distributed.node_split(
            g.nodes["paper"].data["test_mask"],
            pb,
            ntype="paper",
            force_even=True,
        )
    local_nid = pb.partid2nids(pb.partid, "paper").detach().numpy()
    print(
        "part {}, train: {} (local: {}), val: {} (local: {}), test: {} (local: {})".format(
            g.rank(),
            len(train_nid),
            len(np.intersect1d(train_nid.numpy(), local_nid)),
            len(val_nid),
            len(np.intersect1d(val_nid.numpy(), local_nid)),
            len(test_nid),
            len(np.intersect1d(test_nid.numpy(), local_nid)),
        )
    )
    if args.num_gpus == -1:
        device = th.device("cpu")
    else:
        dev_id = g.rank() % args.num_gpus
        device = th.device("cuda:" + str(dev_id))
    labels = g.nodes["paper"].data["labels"][np.arange(g.num_nodes("paper"))]
    all_val_nid = th.LongTensor(
        np.nonzero(
            g.nodes["paper"].data["val_mask"][np.arange(g.num_nodes("paper"))]
        )
    ).squeeze()
    all_test_nid = th.LongTensor(
        np.nonzero(
            g.nodes["paper"].data["test_mask"][np.arange(g.num_nodes("paper"))]
        )
    ).squeeze()
    n_classes = len(th.unique(labels[labels >= 0]))
    print("#classes:", n_classes)

    run(
        args,
        device,
        (
            g,
            n_classes,
            train_nid,
            val_nid,
            test_nid,
            labels,
            all_val_nid,
            all_test_nid,
        ),
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RGCN")
    # distributed training related
    parser.add_argument("--graph-name", type=str, help="graph name")
    parser.add_argument("--id", type=int, help="the partition id")
    parser.add_argument(
        "--ip-config", type=str, help="The file for IP configuration"
    )
    parser.add_argument(
        "--conf-path", type=str, help="The path to the partition config file"
    )

    # rgcn related
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=-1,
        help="the number of GPU device. Use -1 for CPU training",
    )
    parser.add_argument(
        "--dropout", type=float, default=0, help="dropout probability"
    )
    parser.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden units"
    )
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--sparse-lr", type=float, default=1e-2, help="sparse lr rate"
    )
    parser.add_argument(
        "--n-bases",
        type=int,
        default=-1,
        help="number of filter weight matrices, default: -1 [use all]",
    )
    parser.add_argument(
        "--n-layers", type=int, default=2, help="number of propagation rounds"
    )
    parser.add_argument(
        "-e",
        "--n-epochs",
        type=int,
        default=50,
        help="number of training epochs",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="dataset to use"
    )
    parser.add_argument("--l2norm", type=float, default=0, help="l2 norm coef")
    parser.add_argument(
        "--relabel",
        default=False,
        action="store_true",
        help="remove untouched nodes and relabel",
    )
    parser.add_argument(
        "--fanout",
        type=str,
        default="4, 4",
        help="Fan-out of neighbor sampling.",
    )
    parser.add_argument(
        "--validation-fanout",
        type=str,
        default=None,
        help="Fan-out of neighbor sampling during validation.",
    )
    parser.add_argument(
        "--use-self-loop",
        default=False,
        action="store_true",
        help="include self feature as a special relation",
    )
    parser.add_argument(
        "--batch-size", type=int, default=100, help="Mini-batch size. "
    )
    parser.add_argument(
        "--eval-batch-size", type=int, default=128, help="Mini-batch size. "
    )
    parser.add_argument("--log-every", type=int, default=20)
    parser.add_argument(
        "--low-mem",
        default=False,
        action="store_true",
        help="Whether use low mem RelGraphCov",
    )
    parser.add_argument(
        "--sparse-embedding",
        action="store_true",
        help="Use sparse embedding for node embeddings.",
    )
    parser.add_argument(
        "--dgl-sparse",
        action="store_true",
        help="Whether to use DGL sparse embedding",
    )
    parser.add_argument(
        "--layer-norm",
        default=False,
        action="store_true",
        help="Use layer norm",
    )
    parser.add_argument(
        "--local_rank", type=int, help="get rank of the process"
    )
    parser.add_argument(
        "--standalone", action="store_true", help="run in the standalone mode"
    )
    args = parser.parse_args()

    # If the validation fanout is unspecified, fall back to the training fanout.
    if args.validation_fanout is None:
        args.validation_fanout = args.fanout
    print(args)
    main(args)
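

# A typical multi-machine launch (a sketch, assuming DGL's tools/launch.py and
# a pre-partitioned graph; the paths, file names, and process counts below are
# placeholders to adapt):
#
#   python3 tools/launch.py \
#     --workspace ~/workspace \
#     --num_trainers 2 \
#     --num_servers 1 \
#     --num_samplers 0 \
#     --part_config data/ogbn-mag.json \
#     --ip_config ip_config.txt \
#     "python3 entity_classify_dist.py --graph-name ogbn-mag \
#      --dataset ogbn-mag --ip-config ip_config.txt \
#      --conf-path data/ogbn-mag.json --fanout 25,25 --batch-size 1024"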