# multi_gpu_node_classification.py
# Multi-GPU GraphSAGE node classification on ogbn-products using DGL and
# PyTorch DistributedDataParallel.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.optim
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn
9
from dgl.multiprocessing import shared_tensor
10
11
12
import time
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset
13
import tqdm
14

15

16
17
18
19
20
21
22
23
class SAGE(nn.Module):
    """Three-layer GraphSAGE model with mean aggregation.

    Parameters
    ----------
        in_feats : int
            Dimensionality of the input node features.
        n_hidden : int
            Width of the two hidden layers.
        n_classes : int
            Number of output classes (width of the last layer).
    """

    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(0.5)
        self.n_hidden = n_hidden
        self.n_classes = n_classes

    def _forward_layer(self, l, block, x):
        """Apply layer ``l`` to ``x`` over ``block``.

        ReLU and dropout follow every layer except the last one, whose raw
        output is returned unchanged.
        """
        h = self.layers[l](block, x)
        if l != len(self.layers) - 1:
            h = F.relu(h)
            h = self.dropout(h)
        return h

    def forward(self, blocks, x):
        """Run the mini-batch forward pass.

        Parameters
        ----------
            blocks : list
                One sampled block per layer; ``blocks[0]`` consumes the input
                features and the last block produces the output nodes.
            x : tensor
                Input features for the source nodes of ``blocks[0]``.

        Returns
        -------
            tensor
                Output of the last applied layer for the destination nodes.
        """
        h = x
        # zip() stops at the shorter of layers/blocks, as before.
        for l, block in zip(range(len(self.layers)), blocks):
            h = self._forward_layer(l, block, h)
        return h

    def inference(self, g, device, batch_size):
        """
        Perform inference in layer-major order rather than batch-major order.
        That is, infer the first layer for the entire graph, and store the
        intermediate values h_0, before inferring the second layer to generate
        h_1. This is done for two reasons: 1) it limits the effect of node
        degree on the amount of memory used as it only processes 1-hop
        neighbors at a time, and 2) it reduces the total amount of computation
        required as each node is only processed once per layer.

        Every rank of the process group must call this method, because it
        uses DDP-sharded data loading and ``dist.barrier()`` after each layer.
        It temporarily stores per-layer inputs in ``g.ndata['h']`` and
        removes that entry again before returning.

        Parameters
        ----------
            g : DGLGraph
                The graph to perform inference on; must have ``ndata['feat']``.
            device : context
                The device this process should use for inference
            batch_size : int
                The number of items to collect in a batch.

        Returns
        -------
            tensor
                The predictions for all nodes in the graph, held in a shared
                host-memory tensor.
        """
        last = len(self.layers) - 1
        g.ndata['h'] = g.ndata['feat']
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])

        for l in range(len(self.layers)):
            dataloader = dgl.dataloading.DataLoader(
                g, torch.arange(g.num_nodes(), device=device), sampler, device=device,
                batch_size=batch_size, shuffle=False, drop_last=False,
                num_workers=0, use_ddp=True, use_uva=True)
            # in order to prevent running out of GPU memory, we allocate a
            # shared output tensor 'y' in host memory
            y = shared_tensor(
                (g.num_nodes(), self.n_hidden if l != last else self.n_classes))

            # only rank 0 shows a progress bar
            batches = tqdm.tqdm(dataloader) if dist.get_rank() == 0 else dataloader
            for input_nodes, output_nodes, blocks in batches:
                x = blocks[0].srcdata['h']
                h = self._forward_layer(l, blocks[0], x)
                # 'output_nodes' comes back on `device`, while 'y' lives in
                # host memory; move the indices so they match the indexed
                # tensor's device.
                y[output_nodes.to(y.device)] = h.to(y.device)
            # make sure all GPUs are done writing to 'y'
            dist.barrier()
            if l < last:
                # assign the output features of this layer as the new input
                # features for the next layer
                g.ndata['h'] = y
            else:
                # remove the intermediate data from the graph
                g.ndata.pop('h')
        return y


94
def train(rank, world_size, graph, num_classes, split_idx):
    """Train, validate and test GraphSAGE in one worker of a DDP job.

    Parameters
    ----------
        rank : int
            Rank of this process; also used as its CUDA device index.
        world_size : int
            Total number of worker processes.
        graph : DGLGraph
            Graph with ``ndata['feat']`` and ``ndata['label']``. Labels are
            indexed as ``[:, 0]``, so they are expected to be 2-D.
        num_classes : int
            Number of output classes.
        split_idx : dict
            Node-id tensors under the keys ``'train'``, ``'valid'``, ``'test'``.
    """
    torch.cuda.set_device(rank)
    dist.init_process_group('nccl', 'tcp://127.0.0.1:12347', world_size=world_size, rank=rank)

    model = SAGE(graph.ndata['feat'].shape[1], 256, num_classes).cuda()
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

    train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']

    # move ids to GPU ('cuda' is the device selected by set_device above)
    train_idx = train_idx.to('cuda')
    valid_idx = valid_idx.to('cuda')

    # For training, each process/GPU will get a subset of the
    # train_idx/valid_idx, and generate mini-batches independently. This allows
    # the only communication necessary in training to be the all-reduce for
    # the gradients performed by the DDP wrapper (created above).
    sampler = dgl.dataloading.NeighborSampler(
            [15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label'])
    train_dataloader = dgl.dataloading.DataLoader(
            graph, train_idx, sampler,
            device='cuda', batch_size=1024, shuffle=True, drop_last=False,
            num_workers=0, use_ddp=True, use_uva=True)
    valid_dataloader = dgl.dataloading.DataLoader(
            graph, valid_idx, sampler, device='cuda', batch_size=1024, shuffle=True,
            drop_last=False, num_workers=0, use_ddp=True,
            use_uva=True)

    durations = []
    for _ in range(10):
        model.train()
        t0 = time.time()
        for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
            x = blocks[0].srcdata['feat']
            y = blocks[-1].dstdata['label'][:, 0]
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            if it % 20 == 0 and rank == 0:
                acc = MF.accuracy(y_hat, y)
                mem = torch.cuda.max_memory_allocated() / 1000000
                print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB')
        tt = time.time()
        if rank == 0:
            print(tt - t0)
        durations.append(tt - t0)

        model.eval()
        ys = []
        y_hats = []
        with torch.no_grad():
            for input_nodes, output_nodes, blocks in valid_dataloader:
                x = blocks[0].srcdata['feat']
                # same label column as in training, so targets are 1-D
                ys.append(blocks[-1].dstdata['label'][:, 0])
                y_hats.append(model.module(blocks, x))
        # Each rank scores only its own shard; dividing by world_size before
        # the (sum) reduce leaves the mean of per-rank accuracies on rank 0.
        acc = MF.accuracy(torch.cat(y_hats), torch.cat(ys)) / world_size
        dist.reduce(acc, 0)
        if rank == 0:
            print('Validation acc:', acc.item())
        dist.barrier()

    if rank == 0:
        # skip the first epochs (warm-up) when reporting epoch time
        print(np.mean(durations[4:]), np.std(durations[4:]))

    model.eval()
    with torch.no_grad():
        # since we do 1-layer at a time, use a very large batch size
        # (every rank must take part in inference; it is collective)
        pred = model.module.inference(graph, device='cuda', batch_size=2**16)
        if rank == 0:
            acc = MF.accuracy(pred[test_idx], graph.ndata['label'][test_idx, 0])
            print('Test acc:', acc.item())

    dist.destroy_process_group()
168
169
170
171

if __name__ == '__main__':
    # Load ogbn-products once in the parent process; the graph, class count
    # and split are handed to every worker through mp.spawn's args.
    dataset = DglNodePropPredDataset('ogbn-products')
    graph, labels = dataset[0]
    graph.ndata['label'] = labels
    graph.create_formats_()     # must be called before mp.spawn().
    split_idx = dataset.get_idx_split()
    num_classes = dataset.num_classes

    # use all available GPUs
    n_procs = torch.cuda.device_count()
    if n_procs == 0:
        # mp.spawn(nprocs=0) would start no workers and silently do nothing.
        raise SystemExit('This example requires at least one CUDA GPU.')

    # Tested with mp.spawn and fork.  Both worked and got 4s per epoch with 4 GPUs
    # and 3.86s per epoch with 8 GPUs on p2.8x, compared to 5.2s from official examples.
    import torch.multiprocessing as mp
    mp.spawn(train, args=(n_procs, graph, num_classes, split_idx), nprocs=n_procs)