"""
This script trains and tests a GraphSAGE model for node classification on
multiple GPUs with distributed data-parallel training (DDP).

Before reading this example, please familiarize yourself with GraphSAGE node
classification using neighbor sampling by reading the example in
`examples/sampling/node_classification.py`.

This flowchart describes the main functional sequence of the provided example.
main
│
├───> Load and preprocess dataset
│
└───> run (multiprocessing)
      │
      ├───> Init process group and build distributed SAGE model (HIGHLIGHT)
      │
      ├───> train
      │     │
      │     ├───> NeighborSampler
      │     │
      │     └───> Training loop
      │           │
      │           ├───> SAGE.forward
      │           │
      │           └───> Collect validation accuracy (HIGHLIGHT)
      │
      └───> layerwise_infer
            │
            └───> SAGE.inference
                  │
                  ├───> MultiLayerFullNeighborSampler
                  │
                  └───> Use a shared output tensor
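
Example (illustrative invocation; the flags correspond to the argparse
options defined at the bottom of this script):

    python node_classification_sage.py --mode mixed --gpu 0,1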
"""
import argparse
import os
import time

import dgl
import dgl.nn as dglnn

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm
from dgl.data import AsNodePredDataset
from dgl.dataloading import (
    DataLoader,
    MultiLayerFullNeighborSampler,
    NeighborSampler,
)
from dgl.multiprocessing import shared_tensor
from ogb.nodeproppred import DglNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel


class SAGE(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Three-layer GraphSAGE-mean
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size, use_uva):
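        # Layer-wise offline inference: compute representations for all nodes
        # one GNN layer at a time with full-neighbor sampling, so every node
        # is computed only once per layer.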
        g.ndata["h"] = g.ndata["feat"]
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["h"])
        for l, layer in enumerate(self.layers):
            dataloader = DataLoader(
                g,
                torch.arange(g.num_nodes(), device=device),
                sampler,
                device=device,
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
                num_workers=0,
                use_ddp=True,  # use DDP
                use_uva=use_uva,
            )
            # In order to prevent running out of GPU memory, allocate a shared
            # output tensor 'y' in host memory.
            y = shared_tensor(
                (
                    g.num_nodes(),
                    self.hid_size
                    if l != len(self.layers) - 1
                    else self.out_size,
                )
            )
            for input_nodes, output_nodes, blocks in (
                tqdm.tqdm(dataloader) if dist.get_rank() == 0 else dataloader
            ):
                x = blocks[0].srcdata["h"]
                h = layer(blocks[0], x)  # len(blocks) = 1
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                # non_blocking=True (with pinned memory) accelerates data transfer.
                y[output_nodes] = h.to(y.device, non_blocking=True)
            # Use a barrier to make sure all GPUs are done writing to 'y'.
            dist.barrier()
            g.ndata["h"] = y if use_uva else y.to(device)

        g.ndata.pop("h")
        return y


def evaluate(device, model, g, num_classes, dataloader):
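    # Evaluate on this rank's shard of the validation set (the dataloader is
    # created with use_ddp=True, so the indices are split across processes).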
    model.eval()
    ys = []
    y_hats = []
    for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
        with torch.no_grad():
            blocks = [block.to(device) for block in blocks]
            x = blocks[0].srcdata["feat"]
            ys.append(blocks[-1].dstdata["label"])
            y_hats.append(model(blocks, x))
    return MF.accuracy(
        torch.cat(y_hats),
        torch.cat(ys),
        task="multiclass",
        num_classes=num_classes,
    )


def layerwise_infer(
    proc_id, device, g, num_classes, nid, model, use_uva, batch_size=2**10
):
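    # Full-graph layer-wise inference over the given node ids; accuracy is
    # computed and printed on rank 0 only.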
    model.eval()
    with torch.no_grad():
        if not use_uva:
            g = g.to(device)
        pred = model.module.inference(g, device, batch_size, use_uva)
        pred = pred[nid]
        labels = g.ndata["label"][nid].to(pred.device)
    if proc_id == 0:
        acc = MF.accuracy(
            pred, labels, task="multiclass", num_classes=num_classes
        )
        print(f"Test accuracy {acc.item():.4f}")


def train(
    proc_id,
    nprocs,
    device,
    args,
    g,
    num_classes,
    train_idx,
    val_idx,
    model,
    use_uva,
):
    # Instantiate a neighbor sampler
    sampler = NeighborSampler(
        [10, 10, 10],
        prefetch_node_feats=["feat"],
        prefetch_labels=["label"],
        fused=(args.mode != "benchmark"),
    )
    train_dataloader = DataLoader(
        g,
        train_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,
        use_ddp=True,  # Split the training set across the processes
        use_uva=use_uva,
    )
    val_dataloader = DataLoader(
        g,
        val_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,
        use_ddp=True,
        use_uva=use_uva,
    )
    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
    for epoch in range(args.num_epochs):
        t0 = time.time()
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(
            train_dataloader
        ):
            x = blocks[0].srcdata["feat"]
            y = blocks[-1].dstdata["label"].to(torch.int64)
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()  # Gradients are synchronized in DDP
            # .item() detaches the loss so the autograd graph is not retained.
            total_loss += loss.item()
        #####################################################################
        # (HIGHLIGHT) Collect accuracy values from sub-processes and obtain
        # overall accuracy.
        #
        # `torch.distributed.reduce` is used to reduce tensors from all the
        # sub-processes to a specified process, ReduceOp.SUM is used by default.
        #
        # Other multiprocess functions supported by the backend are also
        # available. Please refer to
        # https://pytorch.org/docs/stable/distributed.html
        # for more information.
        #####################################################################
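        # Each process contributes its local accuracy divided by `nprocs`, so
        # the SUM reduction below yields the average accuracy across all
        # processes (exact when the validation shards have equal sizes).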
        acc = (
            evaluate(device, model, g, num_classes, val_dataloader).to(device)
            / nprocs
        )
        t1 = time.time()
        # Reduce `acc` tensors to process 0.
        dist.reduce(tensor=acc, dst=0)
        if proc_id == 0:
            print(
                f"Epoch {epoch:05d} | Loss {total_loss / (it + 1):.4f} | "
                f"Accuracy {acc.item():.4f} | Time {t1 - t0:.4f}"
            )


def run(proc_id, nprocs, devices, g, data, args):
    # Find corresponding device for current process.
    device = devices[proc_id]
    torch.cuda.set_device(device)
    #########################################################################
    # (HIGHLIGHT) Build a data-parallel distributed GraphSAGE model.
    #
    # DDP in PyTorch provides data parallelism across the devices specified
    # by the `process_group`. Gradients are synchronized across each model
    # replica.
    #
    # To prepare a training sub-process, there are four steps involved:
    # 1. Initialize the process group
    # 2. Unpack data for the sub-process.
    # 3. Instantiate a GraphSAGE model on the corresponding device.
    # 4. Parallelize the model with `DistributedDataParallel`.
    #
    # For the detailed usage of `DistributedDataParallel`, please refer to
    # PyTorch documentation.
    #########################################################################
    dist.init_process_group(
        backend="nccl",  # Use NCCL backend for distributed GPU training
        init_method="tcp://127.0.0.1:12345",
        world_size=nprocs,
        rank=proc_id,
    )
    num_classes, train_idx, val_idx, test_idx = data
    if args.mode != "benchmark":
        train_idx = train_idx.to(device)
        val_idx = val_idx.to(device)
        g = g.to(device if args.mode == "puregpu" else "cpu")
    in_size = g.ndata["feat"].shape[1]
    model = SAGE(in_size, 256, num_classes).to(device)
    model = DistributedDataParallel(
        model, device_ids=[device], output_device=device
    )

    # Training.
    use_uva = args.mode == "mixed"

    if proc_id == 0:
        print("Training...")
    train(
        proc_id,
        nprocs,
        device,
        args,
        g,
        num_classes,
        train_idx,
        val_idx,
        model,
        use_uva,
    )

    # Testing.
    if proc_id == 0:
        print("Testing...")
    layerwise_infer(proc_id, device, g, num_classes, test_idx, model, use_uva)

    # Cleanup the process group.
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        default="mixed",
        choices=["mixed", "puregpu", "benchmark"],
        help="Training mode. 'mixed' for CPU-GPU mixed training, "
        "'puregpu' for pure-GPU training.",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="0",
        help="GPU(s) in use. Can be a list of gpu ids for multi-gpu training,"
        " e.g., 0,1,2,3.",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=10,
        help="Number of epochs for train.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="ogbn-products",
        help="Dataset name.",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="dataset",
        help="Root directory of dataset.",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="Number of workers",
    )
    args = parser.parse_args()
    devices = list(map(int, args.gpu.split(",")))
    nprocs = len(devices)
    assert (
        torch.cuda.is_available()
    ), f"Must have GPUs to enable multi-gpu training."
    print(f"Training in {args.mode} mode using {nprocs} GPU(s)")

    # Load and preprocess the dataset.
    print("Loading data")
    dataset = AsNodePredDataset(
        DglNodePropPredDataset(args.dataset_name, root=args.dataset_dir)
    )
    g = dataset[0]
    # Explicitly create desired graph formats before multi-processing to avoid
    # redundant creation in each sub-process and to save memory.
    g.create_formats_()
    if args.dataset_name == "ogbn-arxiv":
        g = dgl.to_bidirected(g, copy_ndata=True)
        g = dgl.add_self_loop(g)
    # Thread limiting to avoid resource competition.
    os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // nprocs)
    data = (
        dataset.num_classes,
        dataset.train_idx,
        dataset.val_idx,
        dataset.test_idx,
    )

    # To use DDP with n GPUs, spawn n processes.
    mp.spawn(
        run,
        args=(nprocs, devices, g, data, args),
        nprocs=nprocs,
    )