"""
This script trains and tests a GraphSAGE model for node classification on
multiple GPUs with distributed data-parallel training (DDP).

Before reading this example, please familiarize yourself with GraphSAGE node
classification using neighbor sampling by reading the example in
`examples/sampling/node_classification.py`.

This flowchart describes the main functional sequence of the provided example.
main
│
├───> Load and preprocess dataset
│
└───> run (multiprocessing)
      │
      ├───> Init process group and build distributed SAGE model (HIGHLIGHT)
      │
      ├───> train
      │     │
      │     ├───> NeighborSampler
      │     │
      │     └───> Training loop
      │           │
      │           ├───> SAGE.forward
      │           │
      │           └───> Collect validation accuracy (HIGHLIGHT)
      │
      └───> layerwise_infer
            │
            └───> SAGE.inference
                  │
                  ├───> MultiLayerFullNeighborSampler
                  │
                  └───> Use a shared output tensor
"""
import argparse
import os
import time

import dgl
import dgl.nn as dglnn

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm
from dgl.data import AsNodePredDataset
from dgl.dataloading import (
    DataLoader,
    MultiLayerFullNeighborSampler,
    NeighborSampler,
)
from dgl.multiprocessing import shared_tensor
from ogb.nodeproppred import DglNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel


class SAGE(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Three-layer GraphSAGE-mean
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
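        # Run message passing over the sampled blocks (MFGs), one block per layer.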
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size, use_uva):
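        # Offline (layer-wise) inference: compute representations layer by
        # layer over all nodes using full-neighbor sampling, so each node is
        # computed exactly once per layer.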
        g.ndata["h"] = g.ndata["feat"]
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["h"])
        for l, layer in enumerate(self.layers):
            dataloader = DataLoader(
                g,
                torch.arange(g.num_nodes(), device=device),
                sampler,
                device=device,
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
                num_workers=0,
                use_ddp=True,  # Split the inference nodes across the DDP processes
                use_uva=use_uva,
            )
            # In order to prevent running out of GPU memory, allocate a shared
            # output tensor 'y' in host memory.
            y = shared_tensor(
                (
                    g.num_nodes(),
                    self.hid_size
                    if l != len(self.layers) - 1
                    else self.out_size,
                )
            )
            for input_nodes, output_nodes, blocks in (
                tqdm.tqdm(dataloader) if dist.get_rank() == 0 else dataloader
            ):
                x = blocks[0].srcdata["h"]
                h = layer(blocks[0], x)  # len(blocks) = 1
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                # Use a non_blocking copy (with pinned memory) to accelerate data transfer
                y[output_nodes] = h.to(y.device, non_blocking=True)
            # Use a barrier to make sure all GPUs are done writing to 'y'
            dist.barrier()
            g.ndata["h"] = y if use_uva else y.to(device)

        g.ndata.pop("h")
        return y


def evaluate(model, g, num_classes, dataloader):
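    # Compute accuracy over the mini-batches produced by `dataloader`; with
    # use_ddp=True each process evaluates only its own shard of the nodes.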
    model.eval()
    ys = []
    y_hats = []
    for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
        with torch.no_grad():
            x = blocks[0].srcdata["feat"]
            ys.append(blocks[-1].dstdata["label"])
            y_hats.append(model(blocks, x))
    return MF.accuracy(
        torch.cat(y_hats),
        torch.cat(ys),
        task="multiclass",
        num_classes=num_classes,
    )


def layerwise_infer(
    proc_id, device, g, num_classes, nid, model, use_uva, batch_size=2**10
):
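    # Run offline layer-wise inference over the full graph, then report the
    # test accuracy on rank 0 only.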
    model.eval()
    with torch.no_grad():
        pred = model.module.inference(g, device, batch_size, use_uva)
        pred = pred[nid]
        labels = g.ndata["label"][nid].to(pred.device)
    if proc_id == 0:
        acc = MF.accuracy(
            pred, labels, task="multiclass", num_classes=num_classes
        )
        print(f"Test accuracy {acc.item():.4f}")


def train(
    proc_id,
    nprocs,
    device,
    g,
    num_classes,
    train_idx,
    val_idx,
    model,
    use_uva,
    num_epochs,
):
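    # Per-process training loop: build DDP-aware dataloaders, run mini-batch
    # training, and report the averaged validation accuracy on rank 0.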
    # Instantiate a neighbor sampler
    sampler = NeighborSampler(
        [10, 10, 10], prefetch_node_feats=["feat"], prefetch_labels=["label"]
    )
    train_dataloader = DataLoader(
        g,
        train_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_ddp=True,  # Split the training set across the DDP processes
        use_uva=use_uva,
    )
    val_dataloader = DataLoader(
        g,
        val_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_ddp=True,
        use_uva=use_uva,
    )
    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
    for epoch in range(num_epochs):
        t0 = time.time()
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(
            train_dataloader
        ):
            x = blocks[0].srcdata["feat"]
            y = blocks[-1].dstdata["label"]
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()  # Gradients are synchronized in DDP
            total_loss += loss
        #####################################################################
        # (HIGHLIGHT) Collect accuracy values from sub-processes and obtain
        # overall accuracy.
        #
        # `torch.distributed.reduce` is used to reduce tensors from all the
        # sub-processes to a specified process, ReduceOp.SUM is used by default.
        #
        # Other collective communication functions supported by the backend
        # are also available. Please refer to
        # https://pytorch.org/docs/stable/distributed.html
        # for more information.
        #####################################################################
        acc = (
            evaluate(model, g, num_classes, val_dataloader).to(device) / nprocs
        )
        t1 = time.time()
        # Reduce `acc` tensors to process 0.
        dist.reduce(tensor=acc, dst=0)
        if proc_id == 0:
            print(
                f"Epoch {epoch:05d} | Loss {total_loss / (it + 1):.4f} | "
                f"Accuracy {acc.item():.4f} | Time {t1 - t0:.4f}"
            )


def run(proc_id, nprocs, devices, g, data, mode, num_epochs):
    # Find corresponding device for current process.
    device = devices[proc_id]
    torch.cuda.set_device(device)
    #########################################################################
    # (HIGHLIGHT) Build a data-parallel distributed GraphSAGE model.
    #
    # DDP in PyTorch provides data parallelism across the devices specified
    # by the `process_group`. Gradients are synchronized across all model
    # replicas during each backward pass.
    #
    # To prepare a training sub-process, there are four steps involved:
    # 1. Initialize the process group.
    # 2. Unpack the data for the sub-process.
    # 3. Instantiate a GraphSAGE model on the corresponding device.
    # 4. Parallelize the model with `DistributedDataParallel`.
    #
    # For the detailed usage of `DistributedDataParallel`, please refer to
    # PyTorch documentation.
    #########################################################################
    dist.init_process_group(
        backend="nccl",  # Use NCCL backend for distributed GPU training
        init_method="tcp://127.0.0.1:12345",
        world_size=nprocs,
        rank=proc_id,
    )
    num_classes, train_idx, val_idx, test_idx = data
    train_idx = train_idx.to(device)
    val_idx = val_idx.to(device)
    g = g.to(device if mode == "puregpu" else "cpu")
    in_size = g.ndata["feat"].shape[1]
    model = SAGE(in_size, 256, num_classes).to(device)
    model = DistributedDataParallel(
        model, device_ids=[device], output_device=device
    )

    # Training.
    use_uva = mode == "mixed"
    if proc_id == 0:
        print("Training...")
    train(
        proc_id,
        nprocs,
        device,
        g,
        num_classes,
        train_idx,
        val_idx,
        model,
        use_uva,
        num_epochs,
    )

    # Testing.
    if proc_id == 0:
        print("Testing...")
    layerwise_infer(proc_id, device, g, num_classes, test_idx, model, use_uva)

    # Cleanup the process group.
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        default="mixed",
        choices=["mixed", "puregpu"],
        help="Training mode. 'mixed' for CPU-GPU mixed training, "
        "'puregpu' for pure-GPU training.",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="0",
        help="GPU(s) in use. Can be a list of gpu ids for multi-gpu training,"
        " e.g., 0,1,2,3.",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=20,
        help="Number of epochs for train.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="ogbn-products",
        help="Dataset name.",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="dataset",
        help="Root directory of dataset.",
    )
    args = parser.parse_args()
    devices = list(map(int, args.gpu.split(",")))
    nprocs = len(devices)
    assert (
        torch.cuda.is_available()
    ), f"Must have GPUs to enable multi-gpu training."
    print(f"Training in {args.mode} mode using {nprocs} GPU(s)")

    # Load and preprocess the dataset.
    print("Loading data")
    dataset = AsNodePredDataset(
        DglNodePropPredDataset(args.dataset_name, root=args.dataset_dir)
    )
    g = dataset[0]
    if args.dataset_name == "ogbn-arxiv":
        g = dgl.to_bidirected(g, copy_ndata=True)
        g = dgl.add_self_loop(g)
    # Explicitly create desired graph formats before multi-processing to avoid
    # redundant creation in each sub-process and to save memory.
    g.create_formats_()
    # Thread limiting to avoid resource competition.
    os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // nprocs)
    data = (
        dataset.num_classes,
        dataset.train_idx,
        dataset.val_idx,
        dataset.test_idx,
    )

    # To use DDP with n GPUs, spawn up n processes.
    mp.spawn(
        run,
        args=(nprocs, devices, g, data, args.mode, args.num_epochs),
        nprocs=nprocs,
    )