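"""Multi-GPU node classification with GraphSAGE on ogbn-products.

One worker process is spawned per GPU; each wraps the model in
DistributedDataParallel and trains with neighbor sampling. Test accuracy is
then computed with layer-wise inference into a shared host-memory tensor.
"""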
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import torch.multiprocessing as mp
import dgl.nn as dglnn
from dgl.multiprocessing import shared_tensor
from dgl.data import AsNodePredDataset
from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm
import argparse

class SAGE(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # three-layer GraphSAGE-mean
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'mean'))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, 'mean'))
        self.dropout = nn.Dropout(0.5)
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size, use_uva):
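        # Layer-wise full-neighbor inference: instead of sampling, compute the
        # representations of all nodes one layer at a time, so only a single
        # layer's activations need to fit in GPU memory at once.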
        g.ndata['h'] = g.ndata['feat']
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
        for l, layer in enumerate(self.layers):
            dataloader = DataLoader(
                g, torch.arange(g.num_nodes(), device=device), sampler, device=device,
                batch_size=batch_size, shuffle=False, drop_last=False,
                num_workers=0, use_ddp=True, use_uva=use_uva)
            # in order to prevent running out of GPU memory, allocate a
            # shared output tensor 'y' in host memory
            y = shared_tensor(
                    (g.num_nodes(), self.hid_size if l != len(self.layers) - 1 else self.out_size))
            # show a progress bar only on rank 0
            for input_nodes, output_nodes, blocks in (
                    tqdm.tqdm(dataloader) if dist.get_rank() == 0 else dataloader):
                x = blocks[0].srcdata['h']
                h = layer(blocks[0], x) # len(blocks) = 1
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                # non_blocking (with pinned memory) to accelerate data transfer
                y[output_nodes] = h.to(y.device, non_blocking=True)
            # make sure all GPUs are done writing to 'y'
            dist.barrier()
            g.ndata['h'] = y if use_uva else y.to(device)

        g.ndata.pop('h')
        return y

def evaluate(model, g, dataloader):
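    # Mini-batch validation on this rank's shard of the dataloader; predictions
    # and labels are accumulated and accuracy is computed once at the end.
    # Note: with torchmetrics >= 0.11, MF.accuracy may additionally require
    # task='multiclass' and num_classes arguments.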
    model.eval()
    ys = []
    y_hats = []
    for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
        with torch.no_grad():
            x = blocks[0].srcdata['feat']
            ys.append(blocks[-1].dstdata['label'])
            y_hats.append(model(blocks, x))
    return MF.accuracy(torch.cat(y_hats), torch.cat(ys))

def layerwise_infer(proc_id, device, g, nid, model, use_uva, batch_size=2**16):
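    # model is wrapped in DistributedDataParallel, so go through .module to
    # reach the underlying SAGE and its custom inference() method.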
    model.eval()
    with torch.no_grad():
        pred = model.module.inference(g, device, batch_size, use_uva)
        pred = pred[nid]
        labels = g.ndata['label'][nid].to(pred.device)
    if proc_id == 0:
        acc = MF.accuracy(pred, labels)
        print("Test Accuracy {:.4f}".format(acc.item()))

def train(proc_id, nprocs, device, g, train_idx, val_idx, model, use_uva):
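    # Sample 10 neighbors per layer; use_ddp=True makes the dataloader shard
    # the seed nodes across ranks so each process trains on a distinct
    # partition of every epoch.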
    sampler = NeighborSampler([10, 10, 10],
                              prefetch_node_feats=['feat'],
                              prefetch_labels=['label'])
    train_dataloader = DataLoader(g, train_idx, sampler, device=device,
                                  batch_size=1024, shuffle=True,
                                  drop_last=False, num_workers=0,
                                  use_ddp=True, use_uva=use_uva)
    val_dataloader = DataLoader(g, val_idx, sampler, device=device,
                                batch_size=1024, shuffle=True,
                                drop_last=False, num_workers=0,
                                use_ddp=True, use_uva=use_uva)
    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
    for epoch in range(10):
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
            x = blocks[0].srcdata['feat']
            y = blocks[-1].dstdata['label']
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss
        # average validation accuracy over all ranks
        acc = evaluate(model, g, val_dataloader).to(device) / nprocs
        dist.reduce(acc, 0)
        if proc_id == 0:
            print("Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} "
                  .format(epoch, total_loss / (it+1), acc.item()))

def run(proc_id, nprocs, devices, g, data, mode):
    # find corresponding device for my rank
    device = devices[proc_id]
    torch.cuda.set_device(device)
    # initialize process group and unpack data for sub-processes
    dist.init_process_group(backend="nccl", init_method='tcp://127.0.0.1:12345',
                            world_size=nprocs, rank=proc_id)
    out_size, train_idx, val_idx, test_idx = data
    train_idx = train_idx.to(device)
    val_idx = val_idx.to(device)
    g = g.to(device if mode == 'puregpu' else 'cpu')
    # create GraphSAGE model (distributed)
    in_size = g.ndata['feat'].shape[1]
    model = SAGE(in_size, 256, out_size).to(device)
    model = DistributedDataParallel(model, device_ids=[device], output_device=device)
    # training + testing
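    # in 'mixed' mode the graph stays in pinned host memory and GPU kernels
    # access it through unified virtual addressing (UVA) during sampling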
    use_uva = (mode == 'mixed')
    train(proc_id, nprocs, device, g, train_idx, val_idx, model, use_uva)
    layerwise_infer(proc_id, device, g, test_idx, model, use_uva)
    # cleanup process group
    dist.destroy_process_group()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default='mixed', choices=['mixed', 'puregpu'],
                        help="Training mode. 'mixed' for CPU-GPU mixed training, "
                        "'puregpu' for pure-GPU training.")
    parser.add_argument("--gpu", type=str, default='0',
                        help="GPU(s) in use. Can be a list of gpu ids for multi-gpu training,"
                        " e.g., 0,1,2,3.")
    args = parser.parse_args()
    devices = list(map(int, args.gpu.split(',')))
    nprocs = len(devices)
    assert torch.cuda.is_available(), "Must have GPUs to enable multi-gpu training."
    print(f'Training in {args.mode} mode using {nprocs} GPU(s)')

    # load and preprocess dataset
    print('Loading data')
    dataset = AsNodePredDataset(DglNodePropPredDataset('ogbn-products'))
    g = dataset[0]
    # avoid creating certain graph formats in each sub-process to save memory
    g.create_formats_()
    # thread limiting to avoid resource competition
    os.environ['OMP_NUM_THREADS'] = str(mp.cpu_count() // 2 // nprocs)
    data = dataset.num_classes, dataset.train_idx, dataset.val_idx, dataset.test_idx

    mp.spawn(run, args=(nprocs, devices, g, data, args.mode), nprocs=nprocs)
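    # mp.spawn launches one worker process per GPU listed in --gpu.
    # Example invocations:
    #   python multi_gpu_node_classification.py --gpu 0,1,2,3
    #   python multi_gpu_node_classification.py --mode puregpu --gpu 0,1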