node_classification.py 13.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""
This script trains and tests a GraphSAGE model for node classification on
multiple GPUs using distributed data-parallel training (DDP) and the GraphBolt
data loader.

Before reading this example, please familiarize yourself with GraphSAGE node
classification using the GraphBolt data loader by reading the example in
`examples/sampling/graphbolt/node_classification.py`.

For details on PyTorch's DDP, please read its documentation:
https://pytorch.org/tutorials/beginner/dist_overview.html and
https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html

This flowchart describes the main functional sequence of the provided example:
main
│
├───> OnDiskDataset pre-processing
│
└───> run (multiprocessing)
      │
      ├───> Init process group and build distributed SAGE model (HIGHLIGHT)
      │
      ├───> train
      │     │
      │     ├───> Get GraphBolt dataloader with DistributedItemSampler
      │     │     (HIGHLIGHT)
      │     │
      │     └───> Training loop
      │           │
      │           ├───> SAGE.forward
      │           │
      │           ├───> Validation set evaluation
      │           │
      │           └───> Collect accuracy and loss from all ranks (HIGHLIGHT)
      │
      └───> Test set evaluation
"""
import argparse
import os
41
import time
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140

import dgl.graphbolt as gb
import dgl.nn as dglnn
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP


class SAGE(nn.Module):
    """GraphSAGE network built from three mean-aggregating SAGEConv layers.

    Args:
        in_size: Dimension of the input node features.
        hidden_size: Dimension of both hidden representations.
        out_size: Dimension of the output logits.
    """

    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        # (input dim, output dim) of each SAGEConv layer, in forward order.
        layer_dims = [
            (in_size, hidden_size),
            (hidden_size, hidden_size),
            (hidden_size, out_size),
        ]
        self.layers = nn.ModuleList(
            [dglnn.SAGEConv(src, dst, "mean") for src, dst in layer_dims]
        )
        self.dropout = nn.Dropout(0.5)
        self.hidden_size = hidden_size
        self.out_size = out_size
        # Force every layer parameter to float32.
        self.set_layer_dtype(torch.float32)

    def set_layer_dtype(self, dtype):
        """Cast the parameters of every module in ``self.layers`` to ``dtype``."""
        layer_params = (
            param for layer in self.layers for param in layer.parameters()
        )
        for param in layer_params:
            param.data = param.data.to(dtype)

    def forward(self, blocks, x):
        """Run the layers over ``blocks`` (paired positionally via ``zip``).

        ReLU followed by dropout is applied after every layer except the one
        at the final index of ``self.layers``.
        """
        final_idx = len(self.layers) - 1
        h = x
        for idx, (conv, block) in enumerate(zip(self.layers, blocks)):
            h = conv(block, h)
            if idx != final_idx:
                h = self.dropout(F.relu(h))
        return h


def create_dataloader(
    args,
    graph,
    features,
    itemset,
    device,
    drop_last=False,
    shuffle=True,
    drop_uneven_inputs=False,
):
    """Build a GraphBolt data loader that splits ``itemset`` across ranks.

    The datapipe is: DistributedItemSampler -> neighbor sampling ->
    ``"feat"`` node-feature fetching -> copy to ``device``.

    Args:
        args: Parsed command-line arguments; ``args.batch_size``,
            ``args.fanout`` and ``args.num_workers`` are read.
            ``args.fanout`` is expected to already be a list of ints.
        graph: Graph that neighbors are sampled from.
        features: Feature store the ``"feat"`` node features are fetched
            from.
        itemset: The item set to iterate over (e.g. ``train_set`` or
            ``valid_set``).
        device: Device every minibatch is copied to.
        drop_last: Whether to drop the last non-full minibatch.
        shuffle: Whether to shuffle the items.
        drop_uneven_inputs: Whether to drop minibatches so that every rank
            yields the same number of minibatches.

    Returns:
        A ``gb.DataLoader`` wrapping the assembled datapipe.
    """
    ############################################################################
    # [HIGHLIGHT]
    # Get a GraphBolt dataloader for node classification tasks with multi-gpu
    # distributed training. DistributedItemSampler instead of ItemSampler should
    # be used.
    ############################################################################

    ############################################################################
    # [Note]:
    # gb.DistributedItemSampler()
    # [Input]:
    # 'item_set': The current dataset. (e.g. `train_set` or `valid_set`)
    # 'batch_size': Specifies the number of samples to be processed together,
    # referred to as a 'mini-batch'. (The term 'mini-batch' is used here to
    # indicate a subset of the entire dataset that is processed together. This
    # is in contrast to processing the entire dataset, known as a 'full batch'.)
    # 'drop_last': Determines whether the last non-full minibatch should be
    # dropped.
    # 'shuffle': Determines if the items should be shuffled.
    # 'drop_uneven_inputs': Determines whether the numbers of minibatches on all
    # ranks should be kept the same by dropping uneven minibatches.
    # The number of replicas is not passed here. The sampler is expected to
    # take it, along with this process's rank, from the `torch.distributed`
    # process group, which must be initialized before this call.
    # [Output]:
    # A DistributedItemSampler object for handling mini-batch sampling on
    # multiple replicas.
    ############################################################################
    datapipe = gb.DistributedItemSampler(
        item_set=itemset,
        batch_size=args.batch_size,
        drop_last=drop_last,
        shuffle=shuffle,
        drop_uneven_inputs=drop_uneven_inputs,
    )
    datapipe = datapipe.sample_neighbor(graph, args.fanout)
    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])

    ############################################################################
    # [Note]:
    # datapipe.copy_to() / gb.CopyTo()
    # [Input]:
    # 'device': The specified device that data should be copied to.
    # [Output]:
    # A CopyTo object copying data in the datapipe to a specified device.
    ############################################################################
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe, num_workers=args.num_workers)

    # Return the fully-initialized DataLoader object.
    return dataloader


@torch.no_grad()
def evaluate(rank, model, dataloader, num_classes, device):
    """Compute this rank's multiclass accuracy of ``model`` on ``dataloader``.

    Args:
        rank: Rank of the current process; only rank 0 shows a progress bar.
        model: The DDP-wrapped model. Its wrapped module (``model.module``)
            is called directly, bypassing the DDP wrapper during evaluation.
        dataloader: Yields minibatches whose ``to_dgl()`` result exposes
            ``blocks``, ``node_features["feat"]`` and ``labels``.
        num_classes: Number of target classes.
        device: Device the returned accuracy tensor is moved to.

    Returns:
        A tensor holding the accuracy over the minibatches seen by this rank
        only. Averaging these per-rank values is exact only when every rank
        evaluates the same number of samples.
    """
    model.eval()
    y = []
    y_hats = []

    # enumerate() has no len(), so tqdm shows a running count instead of
    # querying the loader's length. Only rank 0 draws a progress bar.
    batches = enumerate(dataloader)
    if rank == 0:
        batches = tqdm.tqdm(batches)
    for _, data in batches:
        data = data.to_dgl()
        x = data.node_features["feat"]
        y.append(data.labels)
        y_hats.append(model.module(data.blocks, x))

    res = MF.accuracy(
        torch.cat(y_hats),
        torch.cat(y),
        task="multiclass",
        num_classes=num_classes,
    )

    return res.to(device)


def train(
    world_size,
    rank,
    args,
    train_dataloader,
    valid_dataloader,
    num_classes,
    model,
    device,
):
    """Train the DDP model for ``args.epochs`` epochs and report on rank 0.

    Each epoch makes one pass over ``train_dataloader`` inside a ``Join``
    context and then validates on ``valid_dataloader``. The validation
    accuracy and this rank's mean training loss are then summed onto rank 0,
    which prints them.

    Args:
        world_size: Number of participating processes.
        rank: Rank of this process; only rank 0 prints and shows progress.
        args: Parsed arguments; ``args.lr`` and ``args.epochs`` are read.
        train_dataloader: Minibatches used for optimization.
        valid_dataloader: Minibatches used for per-epoch validation.
        num_classes: Number of target classes (for the accuracy metric).
        model: The DDP-wrapped model being trained.
        device: Device the loss accumulator is allocated on.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        epoch_start = time.time()

        model.train()
        total_loss = torch.zeros((), dtype=torch.float, device=device)
        # Counted explicitly so that a rank with no training minibatches
        # still has a defined value (a loop variable would be unbound).
        num_batches = 0
        ########################################################################
        # (HIGHLIGHT) Use Join Context Manager to solve uneven input problem.
        #
        # Distributed Data Parallel (DDP) training in PyTorch requires every
        # rank to process the same number of inputs; otherwise the program
        # may error or hang. To solve this, PyTorch provides the Join Context
        # Manager. Please refer to
        # https://pytorch.org/tutorials/advanced/generic_join.html for detailed
        # information.
        #
        # Another method is to set `drop_uneven_inputs` as True in GraphBolt's
        # DistributedItemSampler, which will solve this problem by dropping
        # uneven inputs.
        ########################################################################
        with Join([model]):
            batches = enumerate(train_dataloader, start=1)
            if rank == 0:
                batches = tqdm.tqdm(batches)
            for num_batches, data in batches:
                # Convert data to DGL format.
                data = data.to_dgl()

                # The input features are from the source nodes in the first
                # layer's computation graph.
                x = data.node_features["feat"]

                # The ground truth labels are from the destination nodes
                # in the last layer's computation graph.
                y = data.labels

                blocks = data.blocks

                y_hat = model(blocks, x)

                # Compute loss.
                loss = F.cross_entropy(y_hat, y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # detach() so the accumulator does not keep every
                # iteration's autograd history alive for the whole epoch.
                total_loss += loss.detach()

        # Evaluate the model.
        if rank == 0:
            print("Validating...")
        acc = (
            evaluate(
                rank,
                model,
                valid_dataloader,
                num_classes,
                device,
            )
            / world_size
        )
        ########################################################################
        # (HIGHLIGHT) Collect accuracy and loss values from sub-processes and
        # obtain overall average values.
        #
        # `torch.distributed.reduce` is used to reduce tensors from all the
        # sub-processes to a specified process, ReduceOp.SUM is used by default.
        ########################################################################
        dist.reduce(tensor=acc, dst=0)
        # Mean loss over this rank's minibatches. A rank without minibatches
        # contributes 0 instead of failing.
        total_loss /= max(num_batches, 1)
        dist.reduce(tensor=total_loss, dst=0)
        dist.barrier()

        epoch_end = time.time()
        if rank == 0:
            print(
                f"Epoch {epoch:05d} | "
                f"Average Loss {total_loss.item() / world_size:.4f} | "
                f"Accuracy {acc.item():.4f} | "
                f"Time {epoch_end - epoch_start:.4f}"
            )


def run(rank, world_size, args, devices, dataset):
    """Per-process worker: set up DDP, train, test, then tear down.

    Initializes the NCCL process group and builds the DDP-wrapped SAGE model
    and the train/validation/test data loaders. It then trains the model,
    prints the test accuracy on rank 0, and finally destroys the process
    group.

    Args:
        rank: Index of this process; also selects ``devices[rank]``.
        world_size: Total number of processes.
        args: Parsed command-line arguments. ``args.fanout`` is converted in
            place from a comma-separated string to a list of ints.
        devices: GPU ids, indexed by rank.
        dataset: Loaded GraphBolt dataset; ``dataset.tasks[0]`` supplies the
            item sets and the ``"num_classes"`` metadata.

    Raises:
        ValueError: If the number of fan-out values differs from the number
            of model layers.
    """
    # Set up multiprocessing environment.
    device = devices[rank]
    torch.cuda.set_device(device)
    dist.init_process_group(
        backend="nccl",  # Use NCCL backend for distributed GPU training
        init_method="tcp://127.0.0.1:12345",
        world_size=world_size,
        rank=rank,
    )

    graph = dataset.graph
    features = dataset.feature
    train_set = dataset.tasks[0].train_set
    valid_set = dataset.tasks[0].validation_set
    test_set = dataset.tasks[0].test_set
    args.fanout = list(map(int, args.fanout.split(",")))
    num_classes = dataset.tasks[0].metadata["num_classes"]

    in_size = features.size("node", None, "feat")[0]
    hidden_size = 256
    out_size = num_classes

    # Create GraphSAGE model. It should be copied onto a GPU as a replica.
    model = SAGE(in_size, hidden_size, out_size).to(device)
    # The --fanout help requires len(fanout) to equal the number of layers.
    # SAGE.forward zips layers with blocks, so a mismatch would not raise
    # there; it would silently skip layers instead.
    if len(args.fanout) != len(model.layers):
        raise ValueError(
            f"--fanout has {len(args.fanout)} values but the model has "
            f"{len(model.layers)} layers; they must match."
        )
    model = DDP(model)

    # Create data loaders.
    train_dataloader = create_dataloader(
        args,
        graph,
        features,
        train_set,
        device,
        drop_last=False,
        shuffle=True,
        drop_uneven_inputs=False,
    )
    valid_dataloader = create_dataloader(
        args,
        graph,
        features,
        valid_set,
        device,
        drop_last=False,
        shuffle=False,
        drop_uneven_inputs=False,
    )
    test_dataloader = create_dataloader(
        args,
        graph,
        features,
        test_set,
        device,
        drop_last=False,
        shuffle=False,
        drop_uneven_inputs=False,
    )

    # Model training.
    if rank == 0:
        print("Training...")
    train(
        world_size,
        rank,
        args,
        train_dataloader,
        valid_dataloader,
        num_classes,
        model,
        device,
    )

    # Test the model.
    if rank == 0:
        print("Testing...")
    test_acc = (
        evaluate(
            rank,
            model,
            test_dataloader,
            num_classes,
            device,
        )
        / world_size
    )
    dist.reduce(tensor=test_acc, dst=0)
    dist.barrier()
    if rank == 0:
        print(f"Test Accuracy {test_acc.item():.4f}")

    # Release the NCCL communicator and other process-group resources.
    dist.destroy_process_group()


def parse_args():
    """Parse the command-line options of this training script.

    Returns:
        An ``argparse.Namespace`` with ``gpu``, ``epochs``, ``lr``,
        ``batch_size``, ``fanout`` (still a comma-separated string) and
        ``num_workers``.
    """
    parser = argparse.ArgumentParser(
        description="A script does a multi-gpu training on a GraphSAGE model "
        "for node classification using GraphBolt dataloader."
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="0",
        help="GPU(s) in use. Can be a list of gpu ids for multi-gpu training,"
        " e.g., 0,1,2,3.",
    )
    parser.add_argument(
        "--epochs", type=int, default=10, help="Number of training epochs."
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.001,
        help="Learning rate for optimization.",
    )
    parser.add_argument(
        "--batch-size", type=int, default=1024, help="Batch size for training."
    )
    parser.add_argument(
        "--fanout",
        type=str,
        default="10,10,10",
        # The help text previously advertised a default of 15,10,5, which
        # disagreed with the actual default above.
        help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
        " identical with the number of layers in your model. Default: 10,10,10",
    )
    parser.add_argument(
        "--num-workers", type=int, default=0, help="The number of processes."
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    if not torch.cuda.is_available():
        # SystemExit with a message prints it to stderr and exits with a
        # non-zero status, so the failed precondition is visible to callers.
        raise SystemExit("Multi-gpu training needs to be in gpu mode.")

    devices = list(map(int, args.gpu.split(",")))
    world_size = len(devices)

    print(f"Training with {world_size} gpus.")

    # Load and preprocess dataset.
    dataset = gb.BuiltinDataset("ogbn-products").load()

    # Thread limiting to avoid resource competition. The integer division
    # reaches 0 when there are fewer than 2 * world_size CPUs. OpenMP expects
    # a positive OMP_NUM_THREADS, so keep at least one thread per process.
    num_threads = max(1, mp.cpu_count() // 2 // world_size)
    os.environ["OMP_NUM_THREADS"] = str(num_threads)

    mp.set_sharing_strategy("file_system")
    mp.spawn(
        run,
        args=(world_size, args, devices, dataset),
        nprocs=world_size,
        join=True,
    )