import os

os.environ["DGLBACKEND"] = "pytorch"
import argparse
import contextlib
import math
import socket
import time
from functools import wraps
from multiprocessing import Process

import numpy as np
import torch as th
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from torch.utils.data import DataLoader

import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from dgl import DGLGraph
from dgl.data import load_data, register_data_args
from dgl.data.utils import load_graphs
from dgl.distributed import DistDataLoader


def load_subtensor(g, seeds, input_nodes, device, load_feat=True):
    """
    Copies the features and labels of a set of nodes onto the given device.
    """
    batch_inputs = (
        g.ndata["features"][input_nodes].to(device) if load_feat else None
    )
    batch_labels = g.ndata["labels"][seeds].to(device)
    return batch_inputs, batch_labels


class NeighborSampler(object):
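    """Samples multi-layer neighborhoods and builds the message-flow blocks for a minibatch."""
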
    def __init__(self, g, fanouts, sample_neighbors, device, load_feat=True):
        self.g = g
        self.fanouts = fanouts
        self.sample_neighbors = sample_neighbors
        self.device = device
        self.load_feat = load_feat

    def sample_blocks(self, seeds):
        seeds = th.LongTensor(np.asarray(seeds))
        blocks = []
        for fanout in self.fanouts:
            # For each seed node, sample ``fanout`` neighbors.
            frontier = self.sample_neighbors(
                self.g, seeds, fanout, replace=True
            )
            # Then we compact the frontier into a bipartite graph for message passing.
            block = dgl.to_block(frontier, seeds)
            # Obtain the seed nodes for the next layer.
            seeds = block.srcdata[dgl.NID]

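            # Prepend so that ``blocks[0]`` is the input-most layer and
            # ``blocks[-1]`` produces the minibatch outputs.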
            blocks.insert(0, block)

        input_nodes = blocks[0].srcdata[dgl.NID]
        seeds = blocks[-1].dstdata[dgl.NID]
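        # Pull the input features and seed labels from the DistGraph (possibly over
        # the network) on CPU and attach them to the blocks, so the training loop
        # can read them directly from the blocks.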
        batch_inputs, batch_labels = load_subtensor(
            self.g, seeds, input_nodes, "cpu", self.load_feat
        )
        if self.load_feat:
            blocks[0].srcdata["features"] = batch_inputs
        blocks[-1].dstdata["labels"] = batch_labels
        return blocks


class DistSAGE(nn.Module):
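    """GraphSAGE model for distributed training: a stack of SAGEConv layers with mean aggregation."""
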
    def __init__(
        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
    ):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
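        """Minibatch forward pass: block ``i`` feeds layer ``i``."""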
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).

        g : the entire graph.
        x : the input features of the entire node set.

        The inference code is written so that it can handle any number of nodes and layers.
        """
        # During inference with sampling, multi-layer blocks are very inefficient because
        # many computations in the first few layers are repeated.
        # Therefore, we compute the representations of all nodes layer by layer; the nodes
        # at each layer are split into batches.
        # TODO: can we standardize this?
        nodes = dgl.distributed.node_split(
            np.arange(g.number_of_nodes()),
            g.get_partition_book(),
            force_even=True,
        )
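        # Distributed tensor holding this layer's output for every node; it lives in
        # the graph servers so all trainers can read and write their own slices.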
        y = dgl.distributed.DistTensor(
            (g.number_of_nodes(), self.n_hidden),
            th.float32,
            "h",
            persistent=True,
        )
        for l, layer in enumerate(self.layers):
            if l == len(self.layers) - 1:
                y = dgl.distributed.DistTensor(
                    (g.number_of_nodes(), self.n_classes),
                    th.float32,
                    "h_last",
                    persistent=True,
                )

            sampler = NeighborSampler(
                g, [-1], dgl.distributed.sample_neighbors, device
            )
            print(
                "|V|={}, eval batch size: {}".format(
                    g.number_of_nodes(), batch_size
                )
            )
            # Create PyTorch DataLoader for constructing blocks
            dataloader = DistDataLoader(
                dataset=nodes,
                batch_size=batch_size,
                collate_fn=sampler.sample_blocks,
                shuffle=False,
                drop_last=False,
            )

            for blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(device)
                input_nodes = block.srcdata[dgl.NID]
                output_nodes = block.dstdata[dgl.NID]
                h = x[input_nodes].to(device)
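                # In a DGL block the destination nodes come first among the source
                # nodes, so the first ``number_of_dst_nodes()`` rows are the
                # destination-node features.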
                h_dst = h[: block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if l != len(self.layers) - 1:
                    h = self.activation(h)
                    h = self.dropout(h)

                y[output_nodes] = h.cpu()

            x = y
            g.barrier()
        return y


def compute_acc(pred, labels):
    """
    Compute the accuracy of prediction given the labels.
    """
    labels = labels.long()
    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)


def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
    """
    Evaluate the model on the validation and test sets.

    g : the entire graph.
    inputs : the features of all the nodes.
    labels : the labels of all the nodes.
    val_nid : the node IDs for validation.
    test_nid : the node IDs for test.
    batch_size : the number of nodes to compute at the same time.
    device : the device to evaluate on.
    """
    model.eval()
    with th.no_grad():
        pred = model.inference(g, inputs, batch_size, device)
    model.train()
    return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(
        pred[test_nid], labels[test_nid]
    )


def run(args, device, data):
    # Unpack data
    train_nid, val_nid, test_nid, in_feats, n_classes, g = data
    shuffle = True
    # Create sampler
    sampler = NeighborSampler(
        g,
        [int(fanout) for fanout in args.fan_out.split(",")],
        dgl.distributed.sample_neighbors,
        device,
    )

    # Create DataLoader for constructing blocks
    dataloader = DistDataLoader(
        dataset=train_nid.numpy(),
        batch_size=args.batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=shuffle,
        drop_last=False,
    )

    # Define model and optimizer
    model = DistSAGE(
        in_feats,
        args.num_hidden,
        n_classes,
        args.num_layers,
        F.relu,
        args.dropout,
    )
    model = model.to(device)
    if not args.standalone:
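        # Wrap the model in DistributedDataParallel: the default (CPU) wrapper when
        # training without GPUs, otherwise pin the replica to this trainer's GPU.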
        if args.num_gpus == -1:
            model = th.nn.parallel.DistributedDataParallel(model)
        else:
            model = th.nn.parallel.DistributedDataParallel(
                model, device_ids=[device], output_device=device
            )
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    iter_tput = []
    for epoch in range(args.num_epochs):
        tic = time.time()

        sample_time = 0
        forward_time = 0
        backward_time = 0
        update_time = 0
        num_seeds = 0
        num_inputs = 0
        start = time.time()
        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
        step_time = []

        # DDP's join() handles uneven batch counts across workers; in standalone mode
        # the model is a plain nn.Module without join(), so fall back to a no-op.
        with model.join() if not args.standalone else contextlib.nullcontext():
            for step, blocks in enumerate(dataloader):
                tic_step = time.time()
                sample_time += tic_step - start

                # The input nodes lie on the LHS of the first block.
                # The output nodes lie on the RHS of the last block.
                batch_inputs = blocks[0].srcdata["features"]
                batch_labels = blocks[-1].dstdata["labels"]
                batch_labels = batch_labels.long()

                num_seeds += len(blocks[-1].dstdata[dgl.NID])
                num_inputs += len(blocks[0].srcdata[dgl.NID])
                blocks = [block.to(device) for block in blocks]
                batch_labels = batch_labels.to(device)
                # Compute loss and prediction
                start = time.time()
                # print(g.rank(), blocks[0].device, model.module.layers[0].fc_neigh.weight.device, dev_id)
                batch_pred = model(blocks, batch_inputs)
                loss = loss_fcn(batch_pred, batch_labels)
                forward_end = time.time()
                optimizer.zero_grad()
                loss.backward()
                compute_end = time.time()
                forward_time += forward_end - start
                backward_time += compute_end - forward_end

                optimizer.step()
                update_time += time.time() - compute_end

                step_t = time.time() - tic_step
                step_time.append(step_t)
                iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t)
                if step % args.log_every == 0:
                    acc = compute_acc(batch_pred, batch_labels)
                    gpu_mem_alloc = (
                        th.cuda.max_memory_allocated() / 1000000
                        if th.cuda.is_available()
                        else 0
                    )
                    print(
                        "Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB | time {:.3f} s".format(
                            g.rank(),
                            epoch,
                            step,
                            loss.item(),
                            acc.item(),
                            np.mean(iter_tput[3:]),
                            gpu_mem_alloc,
                            np.sum(step_time[-args.log_every :]),
                        )
                    )
                start = time.time()

        toc = time.time()
        print(
            "Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}".format(
                g.rank(),
                toc - tic,
                sample_time,
                forward_time,
                backward_time,
                update_time,
                num_seeds,
                num_inputs,
            )
        )

        if (epoch + 1) % args.eval_every == 0:
            start = time.time()
            val_acc, test_acc = evaluate(
                model if args.standalone else model.module,
                g,
                g.ndata["features"],
                g.ndata["labels"],
                val_nid,
                test_nid,
                args.batch_size_eval,
                device,
            )
            print(
                "Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format(
                    g.rank(), val_acc, test_acc, time.time() - start
                )
            )


def main(args):
    print(socket.gethostname(), "Initializing DGL dist")
    dgl.distributed.initialize(args.ip_config, net_type=args.net_type)
    if not args.standalone:
        print(socket.gethostname(), "Initializing DGL process group")
        th.distributed.init_process_group(backend=args.backend)
    print(socket.gethostname(), "Initializing DistGraph")
    g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
    print(socket.gethostname(), "rank:", g.rank())

    pb = g.get_partition_book()
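    # Split the train/val/test node masks so each trainer gets a roughly even share;
    # when the partitioning recorded per-node trainer assignments ("trainer_id"),
    # prefer nodes co-located with this trainer.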
    if "trainer_id" in g.ndata:
        train_nid = dgl.distributed.node_split(
            g.ndata["train_mask"],
            pb,
            force_even=True,
            node_trainer_ids=g.ndata["trainer_id"],
        )
        val_nid = dgl.distributed.node_split(
            g.ndata["val_mask"],
            pb,
            force_even=True,
            node_trainer_ids=g.ndata["trainer_id"],
        )
        test_nid = dgl.distributed.node_split(
            g.ndata["test_mask"],
            pb,
            force_even=True,
            node_trainer_ids=g.ndata["trainer_id"],
        )
    else:
        train_nid = dgl.distributed.node_split(
            g.ndata["train_mask"], pb, force_even=True
        )
        val_nid = dgl.distributed.node_split(
            g.ndata["val_mask"], pb, force_even=True
        )
        test_nid = dgl.distributed.node_split(
            g.ndata["test_mask"], pb, force_even=True
        )
    local_nid = pb.partid2nids(pb.partid).detach().numpy()
    print(
        "part {}, train: {} (local: {}), val: {} (local: {}), test: {} (local: {})".format(
            g.rank(),
            len(train_nid),
            len(np.intersect1d(train_nid.numpy(), local_nid)),
            len(val_nid),
            len(np.intersect1d(val_nid.numpy(), local_nid)),
            len(test_nid),
            len(np.intersect1d(test_nid.numpy(), local_nid)),
        )
    )
    del local_nid
    if args.num_gpus == -1:
        device = th.device("cpu")
    else:
        dev_id = g.rank() % args.num_gpus
        device = th.device("cuda:" + str(dev_id))
    n_classes = args.n_classes
    if n_classes == -1:
        labels = g.ndata["labels"][np.arange(g.number_of_nodes())]
        n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
        del labels
    print("#classes:", n_classes)

    # Pack data
    in_feats = g.ndata["features"].shape[1]
    data = train_nid, val_nid, test_nid, in_feats, n_classes, g
    run(args, device, data)
    print("parent ends")


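# A typical way to launch this script is through DGL's tools/launch.py, which starts
# the graph servers and the trainer processes on every machine listed in ip_config.txt.
# A sketch only: the workspace path, partition files, and flag values below are
# assumptions, so adjust them to your own setup.
#
#   python3 tools/launch.py \
#       --workspace ~/workspace \
#       --num_trainers 1 \
#       --num_samplers 0 \
#       --num_servers 1 \
#       --part_config data/ogbn-products.json \
#       --ip_config ip_config.txt \
#       "python3 train_dist.py --graph_name ogbn-products \
#        --ip_config ip_config.txt --num_epochs 30 --batch_size 1000"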
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Distributed GraphSAGE")
    register_data_args(parser)
    parser.add_argument("--graph_name", type=str, help="graph name")
    parser.add_argument("--id", type=int, help="the partition id")
    parser.add_argument(
        "--ip_config", type=str, help="The file for IP configuration"
    )
    parser.add_argument(
        "--part_config", type=str, help="The path to the partition config file"
    )
    parser.add_argument("--num_clients", type=int, help="The number of clients")
    parser.add_argument(
        "--n_classes",
        type=int,
        default=-1,
        help="The number of classes. If not specified, this"
        " value will be calculated by scanning all the labels"
        " in the dataset, which may exhaust memory.",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="gloo",
        help="pytorch distributed backend",
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=-1,
        help="the number of GPU devices. Use -1 for CPU training",
    )
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--num_hidden", type=int, default=16)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--fan_out", type=str, default="10,25")
    parser.add_argument("--batch_size", type=int, default=1000)
    parser.add_argument("--batch_size_eval", type=int, default=100000)
    parser.add_argument("--log_every", type=int, default=20)
    parser.add_argument("--eval_every", type=int, default=5)
    parser.add_argument("--lr", type=float, default=0.003)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument(
        "--local_rank", type=int, help="the local rank of the process"
    )
    parser.add_argument(
        "--standalone", action="store_true", help="run in the standalone mode"
    )
    parser.add_argument(
        "--pad-data",
        default=False,
        action="store_true",
        help="Pad train_nid to the same length across machines so that each trainer runs the same number of batches.",
    )
    parser.add_argument(
        "--net_type",
        type=str,
        default="socket",
        help="backend net type, 'socket' or 'tensorpipe'",
    )
    args = parser.parse_args()

    print(args)
    main(args)