import argparse
import socket
import time
from contextlib import contextmanager

import dgl
import dgl.nn.pytorch as dglnn

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm


def load_subtensor(g, seeds, input_nodes, device, load_feat=True):
    """
    Copies features and labels of a set of nodes onto the given device.
    """
    batch_inputs = (
        g.ndata["features"][input_nodes].to(device) if load_feat else None
    )
    batch_labels = g.ndata["labels"][seeds].to(device)
    return batch_inputs, batch_labels


class DistSAGE(nn.Module):
    def __init__(
        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
    ):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
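        # Layer stack with mean aggregation: in_feats -> n_hidden,
        # (n_layers - 2) hidden-to-hidden layers, then n_hidden -> n_classes.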
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        h = x
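        # Each block is the bipartite message-flow graph for one layer;
        # activation and dropout are skipped on the output layer.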
        for i, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if i != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without
        neighbor sampling).

        g : the entire graph.
        x : the input features of the entire node set.

        Distributed layer-wise inference.
        """
        # During inference with sampling, multi-layer blocks are very
        # inefficient because lots of computations in the first few layers
        # are repeated. Therefore, we compute the representation of all nodes
        # layer by layer. The nodes in each layer are split into batches.
        # TODO: can we standardize this?
        nodes = dgl.distributed.node_split(
            np.arange(g.num_nodes()),
            g.get_partition_book(),
            force_even=True,
        )
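        # Intermediate embeddings are written to a DistTensor so that every
        # trainer can read representations computed on other machines when
        # it evaluates the next layer.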
        y = dgl.distributed.DistTensor(
            (g.num_nodes(), self.n_hidden),
            th.float32,
            "h",
            persistent=True,
        )
        for i, layer in enumerate(self.layers):
            if i == len(self.layers) - 1:
                y = dgl.distributed.DistTensor(
                    (g.num_nodes(), self.n_classes),
                    th.float32,
                    "h_last",
                    persistent=True,
                )
            print(f"|V|={g.num_nodes()}, eval batch size: {batch_size}")

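            # A fanout of -1 takes every neighbor, so each layer's
            # representations are computed exactly, one hop at a time.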
            sampler = dgl.dataloading.NeighborSampler([-1])
            dataloader = dgl.dataloading.DistNodeDataLoader(
                g,
                nodes,
                sampler,
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
            )

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(device)
                h = x[input_nodes].to(device)
                h_dst = h[: block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if i != len(self.layers) - 1:
                    h = self.activation(h)
                    h = self.dropout(h)

                y[output_nodes] = h.cpu()

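            # The output of this layer becomes the input of the next; the
            # barrier keeps all trainers in step before the next layer starts.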
            x = y
            g.barrier()
        return y

    @contextmanager
    def join(self):
        """dummy join for standalone"""
        yield


def compute_acc(pred, labels):
    """
    Compute the accuracy of prediction given the labels.
    """
    labels = labels.long()
    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)


def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
    """
    Evaluate the model on the validation set specified by ``val_nid``.
    g : The entire graph.
    inputs : The features of all the nodes.
    labels : The labels of all the nodes.
    val_nid : the node Ids for validation.
    batch_size : Number of nodes to compute at the same time.
    device : The GPU device to evaluate on.
    """
    model.eval()
    with th.no_grad():
        pred = model.inference(g, inputs, batch_size, device)
    model.train()
    return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(
        pred[test_nid], labels[test_nid]
    )


def run(args, device, data):
    # Unpack data
    train_nid, val_nid, test_nid, in_feats, n_classes, g = data
    shuffle = True
    # prefetch_node_feats/prefetch_labels are not supported for DistGraph yet.
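    # `fan_out` such as "10,25" gives the number of neighbors sampled for
    # each GNN layer (one entry per layer).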
    sampler = dgl.dataloading.NeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(",")]
    )
    dataloader = dgl.dataloading.DistNodeDataLoader(
        g,
        train_nid,
        sampler,
        batch_size=args.batch_size,
        shuffle=shuffle,
        drop_last=False,
    )
    # Define model and optimizer
    model = DistSAGE(
        in_feats,
        args.num_hidden,
        n_classes,
        args.num_layers,
        F.relu,
        args.dropout,
    )
    model = model.to(device)
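    # Wrap the model in DistributedDataParallel so gradients are averaged
    # across trainers; for CPU training no device ids are passed.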
    if not args.standalone:
        if args.num_gpus == -1:
            model = th.nn.parallel.DistributedDataParallel(model)
        else:
            model = th.nn.parallel.DistributedDataParallel(
                model, device_ids=[device], output_device=device
            )
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    iter_tput = []
    epoch = 0
    for epoch in range(args.num_epochs):
        tic = time.time()

        sample_time = 0
        forward_time = 0
        backward_time = 0
        update_time = 0
        num_seeds = 0
        num_inputs = 0
        start = time.time()
        # Loop over the dataloader to sample the computation dependency graph
        # as a list of blocks.
        step_time = []

        with model.join():
            for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
                tic_step = time.time()
                sample_time += tic_step - start
                # fetch features/labels into CPU tensors; they are moved
                # to the target device below
                batch_inputs, batch_labels = load_subtensor(
                    g, seeds, input_nodes, "cpu"
                )
                batch_labels = batch_labels.long()
                num_seeds += len(blocks[-1].dstdata[dgl.NID])
                num_inputs += len(blocks[0].srcdata[dgl.NID])
                # move to target device
                blocks = [block.to(device) for block in blocks]
                batch_inputs = batch_inputs.to(device)
                batch_labels = batch_labels.to(device)
                # Compute loss and prediction
                start = time.time()
                batch_pred = model(blocks, batch_inputs)
                loss = loss_fcn(batch_pred, batch_labels)
                forward_end = time.time()
                optimizer.zero_grad()
                loss.backward()
                compute_end = time.time()
                forward_time += forward_end - start
                backward_time += compute_end - forward_end

                optimizer.step()
                update_time += time.time() - compute_end

                step_t = time.time() - tic_step
                step_time.append(step_t)
                iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t)
                if step % args.log_every == 0:
                    acc = compute_acc(batch_pred, batch_labels)
                    gpu_mem_alloc = (
                        th.cuda.max_memory_allocated() / 1000000
                        if th.cuda.is_available()
                        else 0
                    )
                    print(
                        "Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
                        "Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU "
                        "{:.1f} MB | time {:.3f} s".format(
                            g.rank(),
                            epoch,
                            step,
                            loss.item(),
                            acc.item(),
                            np.mean(iter_tput[3:]),
                            gpu_mem_alloc,
                            np.sum(step_time[-args.log_every :]),
                        )
                    )
                start = time.time()

        toc = time.time()
        print(
            "Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, "
            "forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, "
            "#inputs: {}".format(
                g.rank(),
                toc - tic,
                sample_time,
                forward_time,
                backward_time,
                update_time,
                num_seeds,
                num_inputs,
            )
        )
        # Shift to a 1-based epoch count so evaluation runs after every
        # `eval_every`-th epoch.
        epoch += 1

        if epoch % args.eval_every == 0 and epoch != 0:
            start = time.time()
            val_acc, test_acc = evaluate(
                model if args.standalone else model.module,
                g,
                g.ndata["features"],
                g.ndata["labels"],
                val_nid,
                test_nid,
                args.batch_size_eval,
                device,
            )
            print(
                "Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format(
                    g.rank(), val_acc, test_acc, time.time() - start
                )
            )


def main(args):
    print(socket.gethostname(), "Initializing DGL dist")
    dgl.distributed.initialize(args.ip_config, net_type=args.net_type)
    if not args.standalone:
        print(socket.gethostname(), "Initializing DGL process group")
        th.distributed.init_process_group(backend=args.backend)
    print(socket.gethostname(), "Initializing DistGraph")
    g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
    print(socket.gethostname(), "rank:", g.rank())

    pb = g.get_partition_book()
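    # Split the train/val/test masks across trainers. When the partitioning
    # recorded a per-node trainer assignment ("trainer_id"), honor it so each
    # trainer mostly works on nodes stored in its own partition.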
    if "trainer_id" in g.ndata:
        train_nid = dgl.distributed.node_split(
            g.ndata["train_mask"],
            pb,
            force_even=True,
            node_trainer_ids=g.ndata["trainer_id"],
        )
        val_nid = dgl.distributed.node_split(
            g.ndata["val_mask"],
            pb,
            force_even=True,
            node_trainer_ids=g.ndata["trainer_id"],
        )
        test_nid = dgl.distributed.node_split(
            g.ndata["test_mask"],
            pb,
            force_even=True,
            node_trainer_ids=g.ndata["trainer_id"],
        )
    else:
        train_nid = dgl.distributed.node_split(
            g.ndata["train_mask"], pb, force_even=True
        )
        val_nid = dgl.distributed.node_split(
            g.ndata["val_mask"], pb, force_even=True
        )
        test_nid = dgl.distributed.node_split(
            g.ndata["test_mask"], pb, force_even=True
        )
    local_nid = pb.partid2nids(pb.partid).detach().numpy()
    print(
        "part {}, train: {} (local: {}), val: {} (local: {}), test: {} "
        "(local: {})".format(
            g.rank(),
            len(train_nid),
            len(np.intersect1d(train_nid.numpy(), local_nid)),
            len(val_nid),
            len(np.intersect1d(val_nid.numpy(), local_nid)),
            len(test_nid),
            len(np.intersect1d(test_nid.numpy(), local_nid)),
        )
    )
    del local_nid
    if args.num_gpus == -1:
        device = th.device("cpu")
    else:
        dev_id = g.rank() % args.num_gpus
        device = th.device("cuda:" + str(dev_id))
    n_classes = args.n_classes
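    # Infer the number of classes from the labels when it is not given on
    # the command line; NaN labels mark unlabeled nodes.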
    if n_classes == 0:
        labels = g.ndata["labels"][np.arange(g.num_nodes())]
        n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
        del labels
    print("#labels:", n_classes)

    # Pack data
    in_feats = g.ndata["features"].shape[1]
    data = train_nid, val_nid, test_nid, in_feats, n_classes, g
    run(args, device, data)
    print("parent ends")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Distributed GraphSAGE")
    parser.add_argument("--graph_name", type=str, help="graph name")
    parser.add_argument("--id", type=int, help="the partition id")
    parser.add_argument(
        "--ip_config", type=str, help="The file for IP configuration"
    )
    parser.add_argument(
        "--part_config", type=str, help="The path to the partition config file"
    )
    parser.add_argument(
        "--n_classes", type=int, default=0, help="the number of classes"
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="gloo",
        help="pytorch distributed backend",
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=-1,
        help="the number of GPU device. Use -1 for CPU training",
    )
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--num_hidden", type=int, default=16)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--fan_out", type=str, default="10,25")
    parser.add_argument("--batch_size", type=int, default=1000)
    parser.add_argument("--batch_size_eval", type=int, default=100000)
    parser.add_argument("--log_every", type=int, default=20)
    parser.add_argument("--eval_every", type=int, default=5)
    parser.add_argument("--lr", type=float, default=0.003)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument(
        "--local_rank", type=int, help="get rank of the process"
    )
    parser.add_argument(
        "--standalone", action="store_true", help="run in the standalone mode"
    )
    parser.add_argument(
        "--pad-data",
        default=False,
        action="store_true",
        help="Pad train nid to the same length across machine, to ensure num "
Rhett Ying's avatar
Rhett Ying committed
412
        "of batches to be the same.",
    )
    parser.add_argument(
        "--net_type",
        type=str,
        default="socket",
        help="backend net type, 'socket' or 'tensorpipe'",
    )
    args = parser.parse_args()

    print(args)
    main(args)
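
# A typical multi-machine launch, as a sketch: it assumes DGL's
# tools/launch.py and an already partitioned graph (here ogbn-products);
# adjust the workspace, ip_config.txt, and the partition JSON to your setup.
#
#   python3 tools/launch.py \
#     --workspace ~/workspace \
#     --num_trainers 1 \
#     --num_samplers 0 \
#     --num_servers 1 \
#     --part_config data/ogbn-products.json \
#     --ip_config ip_config.txt \
#     "python3 train_dist.py --graph_name ogbn-products \
#      --ip_config ip_config.txt --part_config data/ogbn-products.json \
#      --num_epochs 30 --batch_size 1000"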