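"""Distributed mini-batch training of GraphSAGE with DGL.

Each trainer samples mini-batches from a partitioned DistGraph, trains a
DistSAGE model under DistributedDataParallel, and periodically evaluates it
with distributed layer-wise inference over full neighborhoods.
"""
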
import argparse
import socket
import time
from contextlib import contextmanager

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm

import dgl
import dgl.nn.pytorch as dglnn


def load_subtensor(g, seeds, input_nodes, device, load_feat=True):
    """
    Copies features and labels of a set of nodes onto GPU.
    """
    batch_inputs = (
        g.ndata["features"][input_nodes].to(device) if load_feat else None
    )
    batch_labels = g.ndata["labels"][seeds].to(device)
    return batch_inputs, batch_labels


class DistSAGE(nn.Module):
    def __init__(
        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
    ):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        h = x
        for i, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if i != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        """
        Distributed layer-wise inference with the GraphSAGE model on full
        neighbors (i.e. without neighbor sampling).

        g : the entire graph.
        x : the input features of the entire node set.
        """
        # During inference with sampling, multi-layer blocks are very
        # inefficient because lots of computations in the first few layers
        # are repeated. Therefore, we compute the representation of all nodes
        # layer by layer. The nodes on each layer are of course split into
        # batches.
        # TODO: can we standardize this?
        nodes = dgl.distributed.node_split(
            np.arange(g.num_nodes()),
            g.get_partition_book(),
            force_even=True,
        )
        y = dgl.distributed.DistTensor(
            (g.num_nodes(), self.n_hidden),
            th.float32,
            "h",
            persistent=True,
        )
        for i, layer in enumerate(self.layers):
            if i == len(self.layers) - 1:
                y = dgl.distributed.DistTensor(
                    (g.num_nodes(), self.n_classes),
                    th.float32,
                    "h_last",
                    persistent=True,
                )
            print(f"|V|={g.num_nodes()}, eval batch size: {batch_size}")

            # A fanout of -1 samples all neighbors, so each layer is computed
            # exactly, one layer at a time.
            sampler = dgl.dataloading.NeighborSampler([-1])
            dataloader = dgl.dataloading.DistNodeDataLoader(
                g,
                nodes,
                sampler,
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
            )

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(device)
                h = x[input_nodes].to(device)
                h_dst = h[: block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if i != len(self.layers) - 1:
                    h = self.activation(h)
                    h = self.dropout(h)

                y[output_nodes] = h.cpu()

            x = y
            g.barrier()
        return y

    @contextmanager
    def join(self):
        """Dummy join for standalone mode; stands in for
        DistributedDataParallel.join() when the model is not wrapped in DDP."""
        yield


def compute_acc(pred, labels):
    """
    Compute the accuracy of prediction given the labels.
    """
    labels = labels.long()
    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)


def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
    """
    Evaluate the model on the validation and test sets.
    g : The entire graph.
    inputs : The features of all the nodes.
    labels : The labels of all the nodes.
    val_nid : the node Ids for validation.
    test_nid : the node Ids for test.
    batch_size : Number of nodes to compute at the same time.
    device : The GPU device to evaluate on.
    """
    model.eval()
    with th.no_grad():
        pred = model.inference(g, inputs, batch_size, device)
    model.train()
    return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(
        pred[test_nid], labels[test_nid]
    )


def run(args, device, data):
    # Unpack data
    train_nid, val_nid, test_nid, in_feats, n_classes, g = data
    shuffle = True
    # prefetch_node_feats/prefetch_labels are not supported for DistGraph yet.
    sampler = dgl.dataloading.NeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(",")]
    )
    dataloader = dgl.dataloading.DistNodeDataLoader(
        g,
        train_nid,
        sampler,
        batch_size=args.batch_size,
        shuffle=shuffle,
        drop_last=False,
    )
    # Define model and optimizer
    model = DistSAGE(
        in_feats,
        args.num_hidden,
        n_classes,
        args.num_layers,
        F.relu,
        args.dropout,
    )
    model = model.to(device)
    if not args.standalone:
        if args.num_gpus == -1:
            model = th.nn.parallel.DistributedDataParallel(model)
        else:
            model = th.nn.parallel.DistributedDataParallel(
                model, device_ids=[device], output_device=device
            )
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    iter_tput = []
    epoch = 0
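    # Each dataloader iteration yields (input_nodes, seeds, blocks): "seeds"
    # are the mini-batch training nodes and "blocks" are the sampled
    # message-flow graphs, outermost layer first. model.join() keeps DDP
    # workers in step when they receive different numbers of batches.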
    for epoch in range(args.num_epochs):
        tic = time.time()

        sample_time = 0
        forward_time = 0
        backward_time = 0
        update_time = 0
        num_seeds = 0
        num_inputs = 0
        start = time.time()
        # Loop over the dataloader to sample the computation dependency graph
        # as a list of blocks.
        step_time = []

        with model.join():
            for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
                tic_step = time.time()
                sample_time += tic_step - start
                # fetch features/labels
                batch_inputs, batch_labels = load_subtensor(
                    g, seeds, input_nodes, "cpu"
                )
                batch_labels = batch_labels.long()
                num_seeds += len(blocks[-1].dstdata[dgl.NID])
                num_inputs += len(blocks[0].srcdata[dgl.NID])
                # move to target device
                blocks = [block.to(device) for block in blocks]
                batch_inputs = batch_inputs.to(device)
                batch_labels = batch_labels.to(device)
                # Compute loss and prediction
                start = time.time()
                batch_pred = model(blocks, batch_inputs)
                loss = loss_fcn(batch_pred, batch_labels)
                forward_end = time.time()
                optimizer.zero_grad()
                loss.backward()
                compute_end = time.time()
                forward_time += forward_end - start
                backward_time += compute_end - forward_end

                optimizer.step()
                update_time += time.time() - compute_end

                step_t = time.time() - tic_step
                step_time.append(step_t)
                iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t)
                if step % args.log_every == 0:
                    acc = compute_acc(batch_pred, batch_labels)
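                    # Peak GPU memory allocated by PyTorch so far, in MB
                    # (0 when running on CPU).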
                    gpu_mem_alloc = (
                        th.cuda.max_memory_allocated() / 1000000
                        if th.cuda.is_available()
                        else 0
                    )
                    print(
                        "Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
                        "Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU "
                        "{:.1f} MB | time {:.3f} s".format(
                            g.rank(),
                            epoch,
                            step,
                            loss.item(),
                            acc.item(),
                            np.mean(iter_tput[3:]),
                            gpu_mem_alloc,
                            np.sum(step_time[-args.log_every:]),
                        )
                    )
                start = time.time()

        toc = time.time()
        print(
            "Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, "
            "forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, "
            "#inputs: {}".format(
                g.rank(),
                toc - tic,
                sample_time,
                forward_time,
                backward_time,
                update_time,
                num_seeds,
                num_inputs,
            )
        )
        epoch += 1

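        # "epoch" was incremented above, so it is 1-based here: evaluation
        # runs every args.eval_every epochs.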
        if epoch % args.eval_every == 0 and epoch != 0:
            start = time.time()
            val_acc, test_acc = evaluate(
                model if args.standalone else model.module,
                g,
                g.ndata["features"],
                g.ndata["labels"],
                val_nid,
                test_nid,
                args.batch_size_eval,
                device,
            )
            print(
                "Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format(
                    g.rank(), val_acc, test_acc, time.time() - start
                )
            )


def main(args):
    print(socket.gethostname(), "Initializing DGL dist")
    dgl.distributed.initialize(args.ip_config, net_type=args.net_type)
    if not args.standalone:
        print(socket.gethostname(), "Initializing DGL process group")
        th.distributed.init_process_group(backend=args.backend)
    print(socket.gethostname(), "Initializing DistGraph")
    g = dgl.distributed.DistGraph(
        args.graph_name, part_config=args.part_config
    )
    print(socket.gethostname(), "rank:", g.rank())

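    # Split the train/val/test node IDs across trainers. When the graph
    # carries a "trainer_id" node field, the split follows it so each trainer
    # mostly owns nodes from its local partition.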
    pb = g.get_partition_book()
    if "trainer_id" in g.ndata:
        train_nid = dgl.distributed.node_split(
            g.ndata["train_mask"],
            pb,
            force_even=True,
            node_trainer_ids=g.ndata["trainer_id"],
        )
        val_nid = dgl.distributed.node_split(
            g.ndata["val_mask"],
            pb,
            force_even=True,
            node_trainer_ids=g.ndata["trainer_id"],
        )
        test_nid = dgl.distributed.node_split(
            g.ndata["test_mask"],
            pb,
            force_even=True,
            node_trainer_ids=g.ndata["trainer_id"],
        )
    else:
        train_nid = dgl.distributed.node_split(
            g.ndata["train_mask"], pb, force_even=True
        )
        val_nid = dgl.distributed.node_split(
            g.ndata["val_mask"], pb, force_even=True
        )
        test_nid = dgl.distributed.node_split(
            g.ndata["test_mask"], pb, force_even=True
        )
    local_nid = pb.partid2nids(pb.partid).detach().numpy()
    print(
        "part {}, train: {} (local: {}), val: {} (local: {}), test: {} "
        "(local: {})".format(
            g.rank(),
            len(train_nid),
            len(np.intersect1d(train_nid.numpy(), local_nid)),
            len(val_nid),
            len(np.intersect1d(val_nid.numpy(), local_nid)),
            len(test_nid),
            len(np.intersect1d(test_nid.numpy(), local_nid)),
        )
    )
    del local_nid
    if args.num_gpus == -1:
        device = th.device("cpu")
    else:
        dev_id = g.rank() % args.num_gpus
        device = th.device("cuda:" + str(dev_id))
    n_classes = args.n_classes
    if n_classes == 0:
        labels = g.ndata["labels"][np.arange(g.num_nodes())]
        n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
        del labels
    print("#labels:", n_classes)

    # Pack data
    in_feats = g.ndata["features"].shape[1]
    data = train_nid, val_nid, test_nid, in_feats, n_classes, g
    run(args, device, data)
    print("parent ends")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Distributed GraphSAGE")
    parser.add_argument("--graph_name", type=str, help="graph name")
    parser.add_argument("--id", type=int, help="the partition id")
    parser.add_argument(
        "--ip_config", type=str, help="The file for IP configuration"
    )
    parser.add_argument(
        "--part_config", type=str, help="The path to the partition config file"
    )
    parser.add_argument(
        "--n_classes", type=int, default=0, help="the number of classes"
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="gloo",
        help="pytorch distributed backend",
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=-1,
        help="the number of GPU device. Use -1 for CPU training",
    )
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--num_hidden", type=int, default=16)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--fan_out", type=str, default="10,25")
    parser.add_argument("--batch_size", type=int, default=1000)
    parser.add_argument("--batch_size_eval", type=int, default=100000)
    parser.add_argument("--log_every", type=int, default=20)
    parser.add_argument("--eval_every", type=int, default=5)
    parser.add_argument("--lr", type=float, default=0.003)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument(
        "--local_rank", type=int, help="get rank of the process"
    )
    parser.add_argument(
        "--standalone", action="store_true", help="run in the standalone mode"
    )
    parser.add_argument(
        "--pad-data",
        default=False,
        action="store_true",
        help="Pad train nid to the same length across machines, to ensure "
        "the number of batches is the same.",
    )
    parser.add_argument(
        "--net_type",
        type=str,
        default="socket",
        help="backend net type, 'socket' or 'tensorpipe'",
    )
    args = parser.parse_args()

    print(args)
    main(args)
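
# Example launch (illustrative only; flags, paths, and values depend on your
# setup and DGL version; see the DGL distributed-training docs for the launch
# tool's exact interface):
#
#   python tools/launch.py \
#       --workspace /path/to/workspace \
#       --num_trainers 1 \
#       --num_samplers 0 \
#       --num_servers 1 \
#       --part_config data/ogbn-products.json \
#       --ip_config ip_config.txt \
#       "python train_dist.py --graph_name ogbn-products \
#        --ip_config ip_config.txt --num_epochs 20 --batch_size 1000"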