# multi_gpu_node_classification.py
#
# Multi-GPU node classification with GraphSAGE on ogbn-products, using DGL
# with one training process per GPU.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn

from dgl.utils import pin_memory_inplace, unpin_memory_inplace
from dgl.multiprocessing import shared_tensor

import time
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm


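# A three-layer GraphSAGE model: mean aggregation in every layer, with ReLU
# and dropout applied between layers.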
class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(0.5)
        self.n_hidden = n_hidden
        self.n_classes = n_classes

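    # apply layer 'l' to one block; ReLU and dropout are used on every layer
    # except the last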
    def _forward_layer(self, l, block, x):
        h = self.layers[l](block, x)
        if l != len(self.layers) - 1:
            h = F.relu(h)
            h = self.dropout(h)
        return h

    def forward(self, blocks, x):
        h = x
        for l, block in enumerate(blocks):
            h = self._forward_layer(l, block, h)
        return h

    def inference(self, g, device, batch_size):
        """
        Perform inference in layer-major order rather than batch-major order.
        That is, infer the first layer for the entire graph, and store the
        intermediate values h_0, before inferring the second layer to generate
        h_1. This is done for two reasons: 1) it limits the effect of node
        degree on the amount of memory used, as it only processes 1-hop
        neighbors at a time, and 2) it reduces the total amount of computation
        required, as each node is only processed once per layer.

        Parameters
        ----------
        g : DGLGraph
            The graph to perform inference on.
        device : context
            The device this process should use for inference.
        batch_size : int
            The number of items to collect in a batch.

        Returns
        -------
        Tensor
            The predictions for all nodes in the graph.
        """
        g.ndata['h'] = g.ndata['feat']
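        # a single-layer full-neighbor sampler, since inference is performed
        # one layer at a time over all nodes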
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
        dataloader = dgl.dataloading.DataLoader(
                g, torch.arange(g.num_nodes(), device=device), sampler, device=device,
                batch_size=batch_size, shuffle=False, drop_last=False,
                num_workers=0, use_ddp=True, use_uva=True)

        for l in range(len(self.layers)):
            # in order to prevent running out of GPU memory, we allocate a
            # shared output tensor 'y' in host memory, and pin it to allow
            # UVA access from each GPU during forward propagation
            y = shared_tensor(
                    (g.num_nodes(),
                     self.n_hidden if l != len(self.layers) - 1 else self.n_classes))
            pin_memory_inplace(y)

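            # only rank 0 displays a progress bar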
            loader = tqdm.tqdm(dataloader) if dist.get_rank() == 0 else dataloader
            for input_nodes, output_nodes, blocks in loader:
                x = blocks[0].srcdata['h']
                h = self._forward_layer(l, blocks[0], x)
                y[output_nodes] = h.to(y.device)
            # make sure all GPUs are done writing to 'y'
            dist.barrier()
            if l > 0:
                unpin_memory_inplace(g.ndata['h'])
            if l + 1 < len(self.layers):
                # assign the output features of this layer as the new input
                # features for the next layer
                g.ndata['h'] = y
            else:
                # remove the intermediate data from the graph
                g.ndata.pop('h')
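        # 'y' from the final iteration holds the output logits for every node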
        return y


def train(rank, world_size, graph, num_classes, split_idx):
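    # one process drives one GPU; NCCL rendezvous happens over a hardcoded
    # local TCP address and port, which must be free on this machine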
    torch.cuda.set_device(rank)
    dist.init_process_group('nccl', 'tcp://127.0.0.1:12347', world_size=world_size, rank=rank)

    model = SAGE(graph.ndata['feat'].shape[1], 256, num_classes).cuda()
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

    train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']

    # move ids to GPU
    train_idx = train_idx.to('cuda')
    valid_idx = valid_idx.to('cuda')
    test_idx = test_idx.to('cuda')

    # For training, each process/GPU will get a subset of the
    # train_idx/valid_idx, and generate mini-batches independently. This allows
    # the only communication necessary in training to be the all-reduce for
    # the gradients performed by the DDP wrapper (created above).
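    # sample 15, 10 and 5 neighbors for the first, second and third layer
    # respectively, and prefetch input features and labels with the blocks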
    sampler = dgl.dataloading.NeighborSampler(
            [15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label'])
    train_dataloader = dgl.dataloading.DataLoader(
            graph, train_idx, sampler,
            device='cuda', batch_size=1024, shuffle=True, drop_last=False,
            num_workers=0, use_ddp=True, use_uva=True)
    valid_dataloader = dgl.dataloading.DataLoader(
            graph, valid_idx, sampler, device='cuda', batch_size=1024, shuffle=True,
            drop_last=False, num_workers=0, use_ddp=True,
            use_uva=True)

    durations = []
    for _ in range(10):
        model.train()
        t0 = time.time()
        for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
            x = blocks[0].srcdata['feat']
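            # OGB labels have shape (N, 1), so take the first column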
            y = blocks[-1].dstdata['label'][:, 0]
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            if it % 20 == 0 and rank == 0:
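                # note: with torchmetrics >= 0.11 this call needs a task
                # argument, e.g. MF.accuracy(y_hat, y, task='multiclass',
                # num_classes=num_classes)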
                acc = MF.accuracy(y_hat, y)
                mem = torch.cuda.max_memory_allocated() / 1000000
                print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB')
        tt = time.time()

        if rank == 0:
            print('Epoch time (s):', tt - t0)
        durations.append(tt - t0)

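        # validate: each rank evaluates its shard of valid_idx, divides its
        # accuracy by world_size, and the per-rank values are summed on rank 0,
        # so the result is (approximately) the mean accuracy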
        model.eval()
        ys = []
        y_hats = []
        for it, (input_nodes, output_nodes, blocks) in enumerate(valid_dataloader):
            with torch.no_grad():
                x = blocks[0].srcdata['feat']
                ys.append(blocks[-1].dstdata['label'][:, 0])
                y_hats.append(model.module(blocks, x))
        acc = MF.accuracy(torch.cat(y_hats), torch.cat(ys)) / world_size
        dist.reduce(acc, 0)
        if rank == 0:
            print('Validation acc:', acc.item())
        dist.barrier()

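    # report mean/std of the epoch time, skipping the first 4 warm-up epochs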
    if rank == 0:
        print(np.mean(durations[4:]), np.std(durations[4:]))
    model.eval()
    with torch.no_grad():
        # since we do 1-layer at a time, use a very large batch size
        pred = model.module.inference(graph, device='cuda', batch_size=2**16)
        if rank == 0:
            # 'pred' and the labels live in host memory, so index with CPU ids
            test_idx = test_idx.cpu()
            acc = MF.accuracy(pred[test_idx], graph.ndata['label'][test_idx][:, 0])
            print('Test acc:', acc.item())

if __name__ == '__main__':
    dataset = DglNodePropPredDataset('ogbn-products')
    graph, labels = dataset[0]
    graph.ndata['label'] = labels
    # create all sparse formats up front, so they are built once here and
    # shared with the workers instead of being rebuilt in every spawned process
    graph.create_formats_()
    split_idx = dataset.get_idx_split()
    num_classes = dataset.num_classes
    # use all available GPUs
    n_procs = torch.cuda.device_count()

    # Tested with both mp.spawn and fork; both worked, at 4s per epoch with 4 GPUs
    # and 3.86s per epoch with 8 GPUs on a p2.8x instance, compared to 5.2s for
    # the official examples.
    import torch.multiprocessing as mp
    mp.spawn(train, args=(n_procs, graph, num_classes, split_idx), nprocs=n_procs)