"""
This script trains and tests a GraphSAGE model for node classification on
multiple GPUs with distributed data-parallel training (DDP).

Before reading this example, please familiarize yourself with GraphSAGE node
classification using neighbor sampling by reading the example in
`examples/sampling/node_classification.py`.

This flowchart describes the main functional sequence of the provided example.
main
│
├───> Load and preprocess dataset
│
└───> run (multiprocessing)
      │
      ├───> Init process group and build distributed SAGE model (HIGHLIGHT)
      │
      ├───> train
      │     │
      │     ├───> NeighborSampler
      │     │
      │     └───> Training loop
      │           │
      │           ├───> SAGE.forward
      │           │
      │           └───> Collect validation accuracy (HIGHLIGHT)
      │
      └───> layerwise_infer
            │
            └───> SAGE.inference
                  │
                  ├───> MultiLayerFullNeighborSampler
                  │
                  └───> Use a shared output tensor
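
Example launch command (the GPU ids are illustrative; adjust --gpu to the
devices available on your machine):

    python node_classification_sage.py --gpu 0,1,2,3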
"""
import argparse
import os
import time

import dgl
import dgl.nn as dglnn

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm
from dgl.data import AsNodePredDataset
from dgl.dataloading import (
    DataLoader,
    MultiLayerFullNeighborSampler,
    NeighborSampler,
)
from dgl.multiprocessing import shared_tensor
from ogb.nodeproppred import DglNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel


class SAGE(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Three-layer GraphSAGE-mean
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size, use_uva):
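        # Layer-wise inference: compute representations for all nodes one GNN
        # layer at a time with full-neighbor aggregation, which avoids the
        # neighborhood explosion of multi-layer sampling at test time.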
        g.ndata["h"] = g.ndata["feat"]
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["h"])
        for l, layer in enumerate(self.layers):
            dataloader = DataLoader(
                g,
                torch.arange(g.num_nodes(), device=device),
                sampler,
                device=device,
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
                num_workers=0,
                use_ddp=True,  # Split the node set across processes (DDP)
                use_uva=use_uva,
            )
            # In order to prevent running out of GPU memory, allocate a shared
            # output tensor 'y' in host memory.
            y = shared_tensor(
                (
                    g.num_nodes(),
                    self.hid_size
                    if l != len(self.layers) - 1
                    else self.out_size,
                )
            )
            for input_nodes, output_nodes, blocks in (
                tqdm.tqdm(dataloader) if dist.get_rank() == 0 else dataloader
            ):
                x = blocks[0].srcdata["h"]
                h = layer(blocks[0], x)  # len(blocks) = 1
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                # Copy with non_blocking=True (with pinned memory) to
                # accelerate data transfer.
                y[output_nodes] = h.to(y.device, non_blocking=True)
            # Use a barrier to make sure all GPUs are done writing to 'y'
            dist.barrier()
            g.ndata["h"] = y if use_uva else y.to(device)

        g.ndata.pop("h")
        return y


def evaluate(device, model, g, num_classes, dataloader):
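    # Compute this process's accuracy on its shard of the validation set; the
    # per-process results are averaged across ranks in `train` via
    # `dist.reduce`.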
    model.eval()
    ys = []
    y_hats = []
    for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
        with torch.no_grad():
            blocks = [block.to(device) for block in blocks]
            x = blocks[0].srcdata["feat"]
            ys.append(blocks[-1].dstdata["label"])
            y_hats.append(model(blocks, x))
    return MF.accuracy(
        torch.cat(y_hats),
        torch.cat(ys),
        task="multiclass",
        num_classes=num_classes,
    )


def layerwise_infer(
    proc_id, device, g, num_classes, nid, model, use_uva, batch_size=2**10
):
    model.eval()
    with torch.no_grad():
        if not use_uva:
            g = g.to(device)
        pred = model.module.inference(g, device, batch_size, use_uva)
        pred = pred[nid]
        labels = g.ndata["label"][nid].to(pred.device)
    if proc_id == 0:
        acc = MF.accuracy(
            pred, labels, task="multiclass", num_classes=num_classes
        )
        print(f"Test accuracy {acc.item():.4f}")


def train(
    proc_id,
    nprocs,
    device,
    args,
    g,
    num_classes,
    train_idx,
    val_idx,
    model,
    use_uva,
):
    # Instantiate a neighbor sampler
    if args.mode == "benchmark":
        # A work-around to prevent CUDA running error. For more details, please
        # see https://github.com/dmlc/dgl/issues/6697.
        sampler = NeighborSampler([10, 10, 10], fused=False)
    else:
        sampler = NeighborSampler(
            [10, 10, 10],
            prefetch_node_feats=["feat"],
            prefetch_labels=["label"],
        )
    train_dataloader = DataLoader(
        g,
        train_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,
        use_ddp=True,  # To split the set for each process
        use_uva=use_uva,
    )
    val_dataloader = DataLoader(
        g,
        val_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,
        use_ddp=True,
        use_uva=use_uva,
    )
    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
    for epoch in range(args.num_epochs):
        t0 = time.time()
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(
            train_dataloader
        ):
            x = blocks[0].srcdata["feat"]
            y = blocks[-1].dstdata["label"].to(torch.int64)
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            # DDP synchronizes gradients across replicas during backward().
            loss.backward()
            opt.step()
            total_loss += loss.item()
        #####################################################################
        # (HIGHLIGHT) Collect accuracy values from sub-processes and obtain
        # overall accuracy.
        #
        # `torch.distributed.reduce` is used to reduce tensors from all the
        # sub-processes to a specified process; `ReduceOp.SUM` is used by
        # default.
        #
        # Other multiprocess functions supported by the backend are also
        # available. Please refer to
        # https://pytorch.org/docs/stable/distributed.html
        # for more information.
        #####################################################################
        acc = (
            evaluate(device, model, g, num_classes, val_dataloader).to(device)
            / nprocs
        )
        t1 = time.time()
        # Reduce `acc` tensors to process 0.
        dist.reduce(tensor=acc, dst=0)
        if proc_id == 0:
            print(
                f"Epoch {epoch:05d} | Loss {total_loss / (it + 1):.4f} | "
                f"Accuracy {acc.item():.4f} | Time {t1 - t0:.4f}"
            )


def run(proc_id, nprocs, devices, g, data, args):
    # Find corresponding device for current process.
    device = devices[proc_id]
    torch.cuda.set_device(device)
    #########################################################################
    # (HIGHLIGHT) Build a data-parallel distributed GraphSAGE model.
    #
    # DDP in PyTorch provides data parallelism across the devices specified
    # by the `process_group`. Gradients are synchronized across each model
    # replica.
    #
    # To prepare a training sub-process, there are four steps involved:
    # 1. Initialize the process group
    # 2. Unpack data for the sub-process.
    # 3. Instantiate a GraphSAGE model on the corresponding device.
    # 4. Parallelize the model with `DistributedDataParallel`.
    #
    # For the detailed usage of `DistributedDataParallel`, please refer to
    # PyTorch documentation.
    #########################################################################
    dist.init_process_group(
        backend="nccl",  # Use NCCL backend for distributed GPU training
        init_method="tcp://127.0.0.1:12345",
        world_size=nprocs,
        rank=proc_id,
    )
    num_classes, train_idx, val_idx, test_idx = data
    if args.mode != "benchmark":
        train_idx = train_idx.to(device)
        val_idx = val_idx.to(device)
        g = g.to(device if args.mode == "puregpu" else "cpu")
    in_size = g.ndata["feat"].shape[1]
    model = SAGE(in_size, 256, num_classes).to(device)
    model = DistributedDataParallel(
        model, device_ids=[device], output_device=device
    )

    # Training.
    # In mixed mode, use UVA (unified virtual addressing) so GPU kernels can
    # sample and fetch features directly from the pinned, CPU-resident graph.
    use_uva = args.mode == "mixed"

    if proc_id == 0:
        print("Training...")
    train(
        proc_id,
        nprocs,
        device,
        args,
        g,
        num_classes,
        train_idx,
        val_idx,
        model,
        use_uva,
    )

    # Testing.
    if proc_id == 0:
        print("Testing...")
    layerwise_infer(proc_id, device, g, num_classes, test_idx, model, use_uva)

    # Cleanup the process group.
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        default="mixed",
        choices=["mixed", "puregpu", "benchmark"],
        help="Training mode. 'mixed' for CPU-GPU mixed training, "
        "'puregpu' for pure-GPU training.",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="0",
        help="GPU(s) in use. Can be a list of gpu ids for multi-gpu training,"
        " e.g., 0,1,2,3.",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=10,
        help="Number of epochs for train.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="ogbn-products",
        help="Dataset name.",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="dataset",
        help="Root directory of dataset.",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="Number of workers",
    )
    args = parser.parse_args()
    devices = list(map(int, args.gpu.split(",")))
    nprocs = len(devices)
    assert (
        torch.cuda.is_available()
    ), f"Must have GPUs to enable multi-gpu training."
    print(f"Training in {args.mode} mode using {nprocs} GPU(s)")

    # Load and preprocess the dataset.
    print("Loading data")
    dataset = AsNodePredDataset(
        DglNodePropPredDataset(args.dataset_name, root=args.dataset_dir)
    )
    g = dataset[0]
    # Explicitly create desired graph formats before multi-processing to avoid
    # redundant creation in each sub-process and to save memory.
    g.create_formats_()
    if args.dataset_name == "ogbn-arxiv":
        g = dgl.to_bidirected(g, copy_ndata=True)
        g = dgl.add_self_loop(g)
    # Thread limiting to avoid resource competition.
    os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // nprocs)
    data = (
        dataset.num_classes,
        dataset.train_idx,
        dataset.val_idx,
        dataset.test_idx,
    )

    # To use DDP with n GPUs, spawn up n processes.
    mp.spawn(
        run,
        args=(nprocs, devices, g, data, args),
        nprocs=nprocs,
    )