# multi_gpu_node_classification.py
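"""Multi-GPU node classification on ogbn-products with a 3-layer GraphSAGE.

One training process is spawned per GPU; gradients are synchronized with
DistributedDataParallel over NCCL, mini-batches are built with DGL's
NeighborSampler, and node features are fetched via UVA from pinned CPU
memory (use_uva=True).
"""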
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.optim
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn
import time
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm

class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.layers = nn.ModuleList()
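        # Three GraphSAGE layers with mean aggregation:
        # in_feats -> n_hidden -> n_hidden -> n_classes.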
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(0.5)
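        # Recorded so inference() can size its per-layer output buffers.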
        self.n_hidden = n_hidden
        self.n_classes = n_classes

    def forward(self, blocks, x):
        # `blocks` holds one message flow graph (MFG) per layer, as produced
        # by the neighbor sampler.
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:   # no activation/dropout after the output layer
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
        # The difference between this inference function and the one in the official
        # example is that the intermediate results can also benefit from prefetching.
        # Start from the input features; 'h' holds the current layer's input
        # and is overwritten with each layer's output below.
        g.ndata['h'] = g.ndata['feat']
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
        dataloader = dgl.dataloading.DataLoader(
                g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
                batch_size=batch_size, shuffle=False, drop_last=False,
                num_workers=num_workers, persistent_workers=(num_workers > 0))
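        # Buffer the per-layer outputs on buffer_device (typically the CPU) so
        # the full-graph activations never have to fit in GPU memory.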
        if buffer_device is None:
            buffer_device = device

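        # Layer-wise inference: compute each layer's output for every node in
        # the graph before moving on to the next layer.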
        for l, layer in enumerate(self.layers):
            y = torch.zeros(
                g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
                device=buffer_device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = blocks[0].srcdata['h']
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                y[output_nodes] = h.to(buffer_device)
            g.ndata['h'] = y    # this layer's output becomes the next layer's input
        return y


def train(rank, world_size, graph, num_classes, split_idx):
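    # Bind this process to its GPU and join the NCCL process group via a
    # fixed local TCP rendezvous address.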
    torch.cuda.set_device(rank)
    dist.init_process_group('nccl', 'tcp://127.0.0.1:12347', world_size=world_size, rank=rank)

    # Every rank builds the same model; DistributedDataParallel keeps the
    # replicas in sync by all-reducing gradients during backward().
    model = SAGE(graph.ndata['feat'].shape[1], 256, num_classes).cuda()
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

    train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']

    # With use_uva=True the graph stays in pinned CPU memory, so only the
    # seed-node indices need to be on the GPU.
    train_idx = train_idx.to('cuda')
    valid_idx = valid_idx.to('cuda')

    # Sample 15, 10 and 5 neighbors for the first, second and third layer
    # respectively; prefetch features and labels together with the blocks.
    sampler = dgl.dataloading.NeighborSampler(
            [15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label'])
    # use_ddp=True partitions the (shuffled) seed nodes across ranks so each
    # GPU trains on a disjoint share of every epoch.
    train_dataloader = dgl.dataloading.DataLoader(
            graph, train_idx, sampler,
            device='cuda', batch_size=1000, shuffle=True, drop_last=False,
            num_workers=0, use_ddp=True, use_uva=True)
    valid_dataloader = dgl.dataloading.DataLoader(
            graph, valid_idx, sampler, device='cuda', batch_size=1024, shuffle=True,
            drop_last=False, num_workers=0, use_uva=True)

    durations = []
    for _ in range(10):     # 10 epochs
        model.train()
        t0 = time.time()
        for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
            # Features come from the first block's source nodes; labels from
            # the last block's destination (seed) nodes.
            x = blocks[0].srcdata['feat']
            y = blocks[-1].dstdata['label'][:, 0]   # labels are stored with shape (N, 1)
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            if it % 20 == 0:
                # Note: torchmetrics >= 0.11 additionally requires
                # task='multiclass' and num_classes arguments here.
                acc = MF.accuracy(y_hat, y)
                mem = torch.cuda.max_memory_allocated() / 1000000   # bytes -> MB
                print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB')
        tt = time.time()
        if rank == 0:
            print(tt - t0)
            durations.append(tt - t0)

            # Validate on rank 0 only; the other ranks wait at the barrier below.
            model.eval()
            ys = []
            y_hats = []
            for it, (input_nodes, output_nodes, blocks) in enumerate(valid_dataloader):
                with torch.no_grad():
                    x = blocks[0].srcdata['feat']
                    ys.append(blocks[-1].dstdata['label'][:, 0])    # flatten the (N, 1) labels
                    y_hats.append(model.module(blocks, x))
            acc = MF.accuracy(torch.cat(y_hats), torch.cat(ys))
            print('Validation acc:', acc.item())
        dist.barrier()

    if rank == 0:
        # Mean/std epoch time, skipping the first four epochs as warm-up.
        print(np.mean(durations[4:]), np.std(durations[4:]))
        model.eval()
        with torch.no_grad():
            # Full-graph layer-wise inference, buffering outputs on the CPU.
            pred = model.module.inference(graph, 'cuda', 1000, 12, graph.device)
            # Evaluate on the held-out test split (labels are shaped (N, 1)).
            acc = MF.accuracy(pred[test_idx].to(graph.device), graph.ndata['label'][test_idx, 0])
            print('Test acc:', acc.item())

if __name__ == '__main__':
    dataset = DglNodePropPredDataset('ogbn-products')
    graph, labels = dataset[0]
    graph.ndata['label'] = labels
    graph.create_formats_()     # must be called before mp.spawn().
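    # create_formats_() materializes the graph's sparse formats up front so
    # the spawned workers share them instead of each rebuilding its own copy.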
    split_idx = dataset.get_idx_split()
    num_classes = dataset.num_classes
    n_procs = 4

    # Tested with mp.spawn and fork.  Both worked and got 4s per epoch with 4 GPUs
    # and 3.86s per epoch with 8 GPUs on a p2.8x, compared to 5.2s from the official examples.
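    # A fork-based launch (the other tested option) would look roughly like
    # this, assuming the 'fork' start method is available on the platform:
    #
    #   import torch.multiprocessing as mp
    #   procs = [mp.Process(target=train, args=(i, n_procs, graph, num_classes, split_idx))
    #            for i in range(n_procs)]
    #   for p in procs:
    #       p.start()
    #   for p in procs:
    #       p.join()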
    import torch.multiprocessing as mp
    # mp.spawn passes the process rank as the first argument of train().
    mp.spawn(train, args=(n_procs, graph, num_classes, split_idx), nprocs=n_procs)
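    # Assumes at least n_procs visible GPUs when launched as
    # `python multi_gpu_node_classification.py`.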