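"""Multi-GPU node classification with GraphSAGE on ogbn-products.

One training process is spawned per GPU and coordinated through PyTorch
DistributedDataParallel (NCCL backend) and DGL's DDP-aware DataLoader.
Supports CPU-GPU "mixed" training (graph in host memory, sampled via UVA)
and "puregpu" training (the whole graph replicated on each GPU).
"""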
import argparse
import os

import dgl.nn as dglnn

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm
from dgl.data import AsNodePredDataset
from dgl.dataloading import (
    DataLoader,
    MultiLayerFullNeighborSampler,
    NeighborSampler,
)
from dgl.multiprocessing import shared_tensor
from ogb.nodeproppred import DglNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel


class SAGE(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # three-layer GraphSAGE-mean
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size, use_uva):
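        # Layer-wise offline inference: compute representations for *all*
        # nodes one GNN layer at a time with a 1-layer full-neighbor sampler,
        # instead of sampling a multi-hop neighborhood per test node.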
        g.ndata["h"] = g.ndata["feat"]
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["h"])
        for l, layer in enumerate(self.layers):
            dataloader = DataLoader(
                g,
                torch.arange(g.num_nodes(), device=device),
                sampler,
                device=device,
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
                num_workers=0,
                use_ddp=True,
                use_uva=use_uva,
            )
            # in order to prevent running out of GPU memory, allocate a
            # shared output tensor 'y' in host memory
            y = shared_tensor(
                (
                    g.num_nodes(),
                    self.hid_size
                    if l != len(self.layers) - 1
                    else self.out_size,
                )
            )
            for input_nodes, output_nodes, blocks in (
                tqdm.tqdm(dataloader) if dist.get_rank() == 0 else dataloader
            ):
                x = blocks[0].srcdata["h"]
                h = layer(blocks[0], x)  # len(blocks) = 1
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                # non_blocking (with pinned memory) to accelerate data transfer
                y[output_nodes] = h.to(y.device, non_blocking=True)
            # make sure all GPUs are done writing to 'y'
            dist.barrier()
            g.ndata["h"] = y if use_uva else y.to(device)

        g.ndata.pop("h")
        return y


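# Mini-batch validation: each rank evaluates its own shard of the validation
# batches; train() averages the per-rank accuracies via dist.reduce().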
def evaluate(model, g, dataloader):
    model.eval()
    ys = []
    y_hats = []
    for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
        with torch.no_grad():
            x = blocks[0].srcdata["feat"]
            ys.append(blocks[-1].dstdata["label"])
            y_hats.append(model(blocks, x))
    return MF.accuracy(torch.cat(y_hats), torch.cat(ys))


def layerwise_infer(
    proc_id, device, g, nid, model, use_uva, batch_size=2**16
):
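    # Inference is collective: every rank runs model.module.inference()
    # (its internal DataLoader is DDP-aware), but only rank 0 reports
    # the test accuracy.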
    model.eval()
    with torch.no_grad():
        pred = model.module.inference(g, device, batch_size, use_uva)
        pred = pred[nid]
        labels = g.ndata["label"][nid].to(pred.device)
    if proc_id == 0:
        acc = MF.accuracy(pred, labels)
        print("Test Accuracy {:.4f}".format(acc.item()))


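# Per-rank training loop: sample 3-hop neighborhoods with fanout
# [10, 10, 10], optimize with Adam, and report the epoch's mean loss and
# the all-rank-averaged validation accuracy from rank 0.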
def train(proc_id, nprocs, device, g, train_idx, val_idx, model, use_uva):
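    # with use_ddp=True, each DataLoader shards the (shuffled) seed nodes
    # across ranks every epoch, so each GPU trains and validates on a
    # disjoint partition of the index set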
    sampler = NeighborSampler(
        [10, 10, 10], prefetch_node_feats=["feat"], prefetch_labels=["label"]
    )
    train_dataloader = DataLoader(
        g,
        train_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_ddp=True,
        use_uva=use_uva,
    )
    val_dataloader = DataLoader(
        g,
        val_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_ddp=True,
        use_uva=use_uva,
    )
    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
    for epoch in range(10):
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(
            train_dataloader
        ):
            x = blocks[0].srcdata["feat"]
            y = blocks[-1].dstdata["label"]
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # .item() detaches the scalar loss so the accumulator does not
            # keep each iteration's autograd graph alive
            total_loss += loss.item()
        # sum of per-rank acc / nprocs == all-rank average on rank 0
        acc = evaluate(model, g, val_dataloader).to(device) / nprocs
        dist.reduce(acc, 0)
        if proc_id == 0:
            print(
                "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f}".format(
                    epoch, total_loss / (it + 1), acc.item()
                )
            )


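# Entry point of each spawned sub-process: bind it to one GPU, join the NCCL
# process group, train, run layer-wise test inference, then clean up.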
def run(proc_id, nprocs, devices, g, data, mode):
    # find corresponding device for my rank
    device = devices[proc_id]
    torch.cuda.set_device(device)
    # initialize process group and unpack data for sub-processes
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12345",
        world_size=nprocs,
        rank=proc_id,
    )
    out_size, train_idx, val_idx, test_idx = data
    train_idx = train_idx.to(device)
    val_idx = val_idx.to(device)
    g = g.to(device if mode == "puregpu" else "cpu")
    # create GraphSAGE model (distributed)
    in_size = g.ndata["feat"].shape[1]
    model = SAGE(in_size, 256, out_size).to(device)
    model = DistributedDataParallel(
        model, device_ids=[device], output_device=device
    )
    # training + testing; in "mixed" mode the graph stays in host memory
    # and is sampled from the GPU via UVA (unified virtual addressing)
    use_uva = mode == "mixed"
    train(proc_id, nprocs, device, g, train_idx, val_idx, model, use_uva)
    layerwise_infer(proc_id, device, g, test_idx, model, use_uva)
    # cleanup process group
    dist.destroy_process_group()


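# Example: python multi_gpu_node_classification.py --mode mixed --gpu 0,1,2,3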
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        default="mixed",
        choices=["mixed", "puregpu"],
        help="Training mode. 'mixed' for CPU-GPU mixed training, "
        "'puregpu' for pure-GPU training.",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="0",
        help="GPU(s) in use. Can be a list of gpu ids for multi-gpu training,"
        " e.g., 0,1,2,3.",
    )
    args = parser.parse_args()
    devices = list(map(int, args.gpu.split(",")))
    nprocs = len(devices)
    assert (
        torch.cuda.is_available()
    ), "Must have GPUs to enable multi-gpu training."
    print(f"Training in {args.mode} mode using {nprocs} GPU(s)")

    # load and preprocess dataset
    print("Loading data")
    dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
    g = dataset[0]
    # avoid creating certain graph formats in each sub-process to save memory
    g.create_formats_()
    # thread limiting to avoid resource competition
    os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // nprocs)
    data = (
        dataset.num_classes,
        dataset.train_idx,
        dataset.val_idx,
        dataset.test_idx,
    )

    mp.spawn(run, args=(nprocs, devices, g, data, args.mode), nprocs=nprocs)