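"""Multi-GPU node classification with GraphSAGE on ogbn-products.

One training process is spawned per GPU and the model is wrapped in PyTorch
DistributedDataParallel. Training uses neighbor sampling; testing uses
layer-wise full-neighbor inference. In "mixed" mode the graph stays in host
memory and minibatches are sampled via UVA, while "puregpu" mode copies the
whole graph onto every GPU.
"""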
import argparse
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm
from ogb.nodeproppred import DglNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel

import dgl.nn as dglnn
from dgl.data import AsNodePredDataset
from dgl.dataloading import (
    DataLoader,
    MultiLayerFullNeighborSampler,
    NeighborSampler,
)
from dgl.multiprocessing import shared_tensor


class SAGE(nn.Module):
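    """Three-layer GraphSAGE model with mean aggregation.

    `forward` runs on sampled blocks (message flow graphs) for minibatch
    training; `inference` computes node representations for the full graph
    one layer at a time.
    """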
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # three-layer GraphSAGE-mean
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
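        """Minibatch forward pass; expects one block per SAGE layer."""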
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size, use_uva):
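        """Layer-wise full-neighbor inference over the entire graph.

        Each layer is computed for all nodes before moving on to the next
        one, which avoids the fan-out of multi-layer sampling. Intermediate
        results are stored in a tensor shared across processes in host
        memory, so they do not have to fit in GPU memory.
        """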
        g.ndata["h"] = g.ndata["feat"]
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["h"])
        for l, layer in enumerate(self.layers):
            dataloader = DataLoader(
                g,
                torch.arange(g.num_nodes(), device=device),
                sampler,
                device=device,
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
                num_workers=0,
                use_ddp=True,
                use_uva=use_uva,
            )
            # in order to prevent running out of GPU memory, allocate a
            # shared output tensor 'y' in host memory
            y = shared_tensor(
                (
                    g.num_nodes(),
                    self.hid_size
                    if l != len(self.layers) - 1
                    else self.out_size,
                )
            )
            for input_nodes, output_nodes, blocks in (
                tqdm.tqdm(dataloader) if dist.get_rank() == 0 else dataloader
            ):
                x = blocks[0].srcdata["h"]
                h = layer(blocks[0], x)  # len(blocks) = 1
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                # non_blocking (with pinned memory) to accelerate data transfer
                y[output_nodes] = h.to(y.device, non_blocking=True)
            # make sure all GPUs are done writing to 'y'
            dist.barrier()
            g.ndata["h"] = y if use_uva else y.to(device)

        g.ndata.pop("h")
        return y


def evaluate(model, g, dataloader):
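    """Compute accuracy over the minibatches produced by `dataloader`."""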
    model.eval()
    ys = []
    y_hats = []
    for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
        with torch.no_grad():
            x = blocks[0].srcdata["feat"]
            ys.append(blocks[-1].dstdata["label"])
            y_hats.append(model(blocks, x))
    return MF.accuracy(torch.cat(y_hats), torch.cat(ys))


def layerwise_infer(
    proc_id, device, g, nid, model, use_uva, batch_size=2**16
):
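    """Evaluate on the test node set with layer-wise full-graph inference.

    All ranks participate in inference; only rank 0 computes and prints the
    test accuracy.
    """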
    model.eval()
    with torch.no_grad():
        pred = model.module.inference(g, device, batch_size, use_uva)
        pred = pred[nid]
        labels = g.ndata["label"][nid].to(pred.device)
    if proc_id == 0:
        acc = MF.accuracy(pred, labels)
        print("Test Accuracy {:.4f}".format(acc.item()))


def train(proc_id, nprocs, device, g, train_idx, val_idx, model, use_uva):
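    """Train with neighbor sampling and validate after every epoch.

    Gradients are synchronized across GPUs by DistributedDataParallel, and
    the validation accuracy is averaged over all ranks and printed on rank 0.
    """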
    sampler = NeighborSampler(
        [10, 10, 10], prefetch_node_feats=["feat"], prefetch_labels=["label"]
    )
    train_dataloader = DataLoader(
        g,
        train_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_ddp=True,
        use_uva=use_uva,
    )
    val_dataloader = DataLoader(
        g,
        val_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_ddp=True,
        use_uva=use_uva,
    )
    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
    for epoch in range(10):
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(
            train_dataloader
        ):
            x = blocks[0].srcdata["feat"]
            y = blocks[-1].dstdata["label"]
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss
        # each rank contributes its local accuracy divided by nprocs; the
        # default (sum) reduction onto rank 0 therefore yields the average
        acc = evaluate(model, g, val_dataloader).to(device) / nprocs
        dist.reduce(acc, 0)
        if proc_id == 0:
            print(
                "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                    epoch, total_loss / (it + 1), acc.item()
                )
            )


def run(proc_id, nprocs, devices, g, data, mode):
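    """Entry point of each spawned training process (one process per GPU)."""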
    # find corresponding device for my rank
    device = devices[proc_id]
    torch.cuda.set_device(device)
    # initialize process group and unpack data for sub-processes
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12345",
        world_size=nprocs,
        rank=proc_id,
    )
    out_size, train_idx, val_idx, test_idx = data
    train_idx = train_idx.to(device)
    val_idx = val_idx.to(device)
    g = g.to(device if mode == "puregpu" else "cpu")
    # create GraphSAGE model (distributed)
    in_size = g.ndata["feat"].shape[1]
    model = SAGE(in_size, 256, out_size).to(device)
    model = DistributedDataParallel(
        model, device_ids=[device], output_device=device
    )
    # training + testing
    use_uva = mode == "mixed"
    train(proc_id, nprocs, device, g, train_idx, val_idx, model, use_uva)
    layerwise_infer(proc_id, device, g, test_idx, model, use_uva)
    # cleanup process group
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        default="mixed",
        choices=["mixed", "puregpu"],
        help="Training mode. 'mixed' for CPU-GPU mixed training, "
        "'puregpu' for pure-GPU training.",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="0",
        help="GPU(s) in use. Can be a list of gpu ids for multi-gpu training,"
        " e.g., 0,1,2,3.",
    )
    args = parser.parse_args()
    devices = list(map(int, args.gpu.split(",")))
    nprocs = len(devices)
    assert (
        torch.cuda.is_available()
    ), "Must have GPUs to enable multi-gpu training."
    print(f"Training in {args.mode} mode using {nprocs} GPU(s)")

    # load and preprocess dataset
    print("Loading data")
    dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
    g = dataset[0]
    # avoid creating certain graph formats in each sub-process to save memory
    g.create_formats_()
    # thread limiting to avoid resource competition
    os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // nprocs)
    data = (
        dataset.num_classes,
        dataset.train_idx,
        dataset.val_idx,
        dataset.test_idx,
    )
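    # launch one process per GPU; mp.spawn passes the process index as the
    # first argument of `run`, e.g.:
    #   python multi_gpu_node_classification.py --mode mixed --gpu 0,1,2,3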
    mp.spawn(run, args=(nprocs, devices, g, data, args.mode), nprocs=nprocs)