entity_sample_multi_gpu.py 8.16 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
import argparse
2
import os
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
3
4

import dgl
5
import torch
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
6
7
import torch.distributed as dist
import torch.multiprocessing as mp
8
import torch.nn as nn
Mufei Li's avatar
Mufei Li committed
9
import torch.nn.functional as F
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
10
11
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from dgl.dataloading import DataLoader, MultiLayerNeighborSampler
12
from dgl.nn.pytorch import RelGraphConv
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
13
14
15
from torch.nn.parallel import DistributedDataParallel
from torchmetrics.functional import accuracy

Mufei Li's avatar
Mufei Li committed
16

17
18
19
20
21
class RGCN(nn.Module):
    """Two-layer relational GCN on sampled blocks with learnable node embeddings.

    Parameters
    ----------
    num_nodes : int
        Number of rows in the node embedding table.
    h_dim : int
        Size of the node embeddings and of the hidden layer.
    out_dim : int
        Output size of the second layer (callers pass the number of classes).
    num_rels : int
        Number of relation types; also used as the number of bases of the
        basis-decomposition regularizer in both layers.
    """

    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        # Input features are looked up from a learnable embedding table.
        self.emb = nn.Embedding(num_nodes, h_dim)
        # two-layer RGCN
        self.conv1 = RelGraphConv(
            h_dim,
            h_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )
        self.conv2 = RelGraphConv(
            h_dim,
            out_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )

    def forward(self, g):
        """Compute output scores for the destination nodes of the last block.

        Parameters
        ----------
        g : sequence
            Two blocks, indexed as ``g[0]`` and ``g[1]``. Each must carry the
            edge features ``dgl.ETYPE`` and ``"norm"``; ``g[0]`` must also
            carry ``dgl.NID`` on its source nodes.

        Returns
        -------
        Tensor
            Output of the second convolution (no activation applied).
        """
        first_block, second_block = g[0], g[1]
        # Embedding rows are selected by the IDs stored on the input nodes.
        x = self.emb(first_block.srcdata[dgl.NID])
        h = F.relu(
            self.conv1(
                first_block,
                x,
                first_block.edata[dgl.ETYPE],
                first_block.edata["norm"],
            )
        )
        h = self.conv2(
            second_block,
            h,
            second_block.edata[dgl.ETYPE],
            second_block.edata["norm"],
        )
        return h
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
46
47


48
def evaluate(model, labels, num_classes, dataloader, inv_target):
    """Evaluate ``model`` on this process's share of ``dataloader``.

    Parameters
    ----------
    model : callable
        Called as ``model(blocks)``; must also provide ``eval()``.
    labels : Tensor
        Labels, indexed by the IDs obtained through ``inv_target``.
    num_classes : int
        Number of classes, forwarded to the accuracy metric.
    dataloader : iterable
        Yields ``(input_nodes, output_nodes, blocks)`` triples.
    inv_target : Tensor
        Lookup table applied to ``output_nodes`` before indexing ``labels``.

    Returns
    -------
    Tensor
        CPU tensor ``[accuracy * num_seeds, num_seeds]``. Callers sum these
        pairs across processes and divide to get the global accuracy.
        ``[0.0, 0.0]`` is returned when ``dataloader`` yields no batches.
    """
    model.eval()
    eval_logits = []
    eval_seeds = []
    with torch.no_grad():
        for _, output_nodes, blocks in dataloader:
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                # The model reads the per-edge "norm" feature from each block.
                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            eval_logits.append(logits.detach().cpu())
            eval_seeds.append(output_nodes.detach().cpu())

    # torch.cat() rejects an empty list; an empty share contributes nothing
    # to the sums that callers reduce across processes.
    if not eval_logits:
        return torch.tensor([0.0, 0.0])

    eval_logits = torch.cat(eval_logits)
    eval_seeds = torch.cat(eval_seeds)
    num_seeds = len(eval_seeds)
    # Scale the local accuracy by the local seed count so that results from
    # different processes can simply be added together.
    loc_sum = accuracy(
        eval_logits.argmax(dim=1),
        labels[eval_seeds].cpu(),
        task="multiclass",
        num_classes=num_classes,
    ) * float(num_seeds)
    return torch.tensor([loc_sum.item(), float(num_seeds)])

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
71

72
73
74
75
76
77
78
79
80
81
82
def train(
    proc_id,
    device,
    g,
    target_idx,
    labels,
    num_classes,
    train_idx,
    inv_target,
    model,
    num_epochs=50,
    batch_size=100,
    lr=1e-2,
    weight_decay=5e-4,
):
    """Train ``model`` with neighbor sampling and report accuracy every epoch.

    Parameters
    ----------
    proc_id : int
        Rank of this process; only rank 0 prints.
    device
        Device passed to the dataloaders and used for the reduced metric.
    g
        Graph to sample from.
    target_idx : Tensor
        Node IDs of the target category; ``target_idx[train_idx]`` are the
        seed nodes.
    labels : Tensor
        Labels, indexed by ``inv_target[output_nodes]``.
    num_classes : int
        Number of classes, forwarded to ``evaluate``.
    train_idx : Tensor
        Positions into ``target_idx`` selecting the training nodes.
    inv_target : Tensor
        Lookup table applied to the sampled output node IDs.
    model : nn.Module
        Model called as ``model(blocks)``.
    num_epochs : int, optional
        Number of training epochs (default 50).
    batch_size : int, optional
        Seed nodes per mini-batch for both loaders (default 100).
    lr : float, optional
        Adam learning rate (default 1e-2).
    weight_decay : float, optional
        Adam weight decay (default 5e-4).
    """
    # define loss function and optimizer
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )
    # construct sampler and dataloader
    sampler = MultiLayerNeighborSampler([4, 4])
    train_loader = DataLoader(
        g,
        target_idx[train_idx],
        sampler,
        device=device,
        batch_size=batch_size,
        shuffle=True,
        use_ddp=True,
    )
    # no separate validation subset, use train index instead for validation
    val_loader = DataLoader(
        g,
        target_idx[train_idx],
        sampler,
        device=device,
        batch_size=batch_size,
        shuffle=False,
        use_ddp=True,
    )
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        # Counted explicitly so the mean loss stays defined even when this
        # process receives no batches.
        num_batches = 0
        for _, output_nodes, blocks in train_loader:
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                # The model reads the per-edge "norm" feature from each block.
                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            loss = loss_fcn(logits, labels[output_nodes])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # torchmetric accuracy defined as num_correct_labels / num_train_nodes
        # loc_acc_split = [loc_accuracy * loc_num_train_nodes, loc_num_train_nodes]
        loc_acc_split = evaluate(
            model, labels, num_classes, val_loader, inv_target
        ).to(device)
        # Combine the per-process pairs on rank 0.
        dist.reduce(loc_acc_split, 0)
        if proc_id == 0:
            acc = loc_acc_split[0] / loc_acc_split[1]
            print(
                "Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} ".format(
                    epoch, total_loss / max(num_batches, 1), acc.item()
                )
            )


135
136
137
138
139
def run(proc_id, nprocs, devices, g, data):
    """Entry point of one worker process: set up, train, test and clean up.

    Parameters
    ----------
    proc_id : int
        Rank of this process; also the index into ``devices``.
    nprocs : int
        Total number of worker processes (the world size).
    devices : list
        Device IDs; ``devices[proc_id]`` is used by this process.
    g
        Graph to sample from.
    data : tuple
        ``(num_rels, num_classes, labels, train_idx, test_idx, target_idx,
        inv_target)`` as prepared by the parent process.
    """
    # find corresponding device for my rank
    device = devices[proc_id]
    torch.cuda.set_device(device)
    # initialize process group and unpack data for sub-processes
    # NOTE(review): address and port are fixed; concurrent runs on the same
    # host would presumably conflict -- confirm before sharing a machine.
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12345",
        world_size=nprocs,
        rank=proc_id,
    )
    (
        num_rels,
        num_classes,
        labels,
        train_idx,
        test_idx,
        target_idx,
        inv_target,
    ) = data
    labels = labels.to(device)
    inv_target = inv_target.to(device)
    # create RGCN model (distributed)
    in_size = g.num_nodes()
    model = RGCN(in_size, 16, num_classes, num_rels).to(device)
    model = DistributedDataParallel(
        model, device_ids=[device], output_device=device
    )
    # training + testing
    train(
        proc_id,
        device,
        g,
        target_idx,
        labels,
        num_classes,
        train_idx,
        inv_target,
        model,
    )
    test_sampler = MultiLayerNeighborSampler(
        [-1, -1]
    )  # -1 for sampling all neighbors
    test_loader = DataLoader(
        g,
        target_idx[test_idx],
        test_sampler,
        device=device,
        batch_size=32,
        shuffle=False,
        use_ddp=True,
    )
    loc_acc_split = evaluate(
        model, labels, num_classes, test_loader, inv_target
    ).to(device)
    # Combine the per-process [accuracy * count, count] pairs on rank 0.
    dist.reduce(loc_acc_split, 0)
    if proc_id == 0:
        acc = loc_acc_split[0] / loc_acc_split[1]
        # Format a Python float, matching the per-epoch report in train().
        print("Test accuracy {:.4f}".format(acc.item()))
    # cleanup process group
    dist.destroy_process_group()
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214


if __name__ == "__main__":
    # Script entry point: parse arguments, preprocess the dataset once in the
    # parent process, then spawn one worker process per requested GPU.
    parser = argparse.ArgumentParser(
        description="RGCN for entity classification with sampling (multi-gpu)"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="aifb",
        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="0",
        help="GPU(s) in use. Can be a list of gpu ids for multi-gpu training,"
        " e.g., 0,1,2,3.",
    )
    args = parser.parse_args()
    devices = list(map(int, args.gpu.split(",")))
    nprocs = len(devices)
    print(
        "Training with DGL built-in RGCN module with sampling using",
        nprocs,
        "GPU(s)",
    )

    # load and preprocess dataset at master(parent) process
    dataset_classes = {
        "aifb": AIFBDataset,
        "mutag": MUTAGDataset,
        "bgs": BGSDataset,
        "am": AMDataset,
    }
    if args.dataset not in dataset_classes:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    data = dataset_classes[args.dataset]()
    g = data[0]
    num_rels = len(g.canonical_etypes)
    category = data.predict_category
    labels = g.nodes[category].data.pop("labels")
    train_mask = g.nodes[category].data.pop("train_mask")
    test_mask = g.nodes[category].data.pop("test_mask")
    # find target category and node id
    category_id = g.ntypes.index(category)
    g = dgl.to_homogeneous(g)
    node_ids = torch.arange(g.num_nodes())
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    # rename the fields as they can be changed by DataLoader
    g.ndata["ntype"] = g.ndata.pop(dgl.NTYPE)
    g.ndata["type_id"] = g.ndata.pop(dgl.NID)
    # find the mapping (inv_target) from global node IDs to type-specific node IDs
    # Only the entries at target_idx are assigned; torch.empty leaves the
    # remaining entries uninitialized.
    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64)
    inv_target[target_idx] = torch.arange(
        0, target_idx.shape[0], dtype=inv_target.dtype
    )
    # avoid creating certain graph formats and train/test indexes in each sub-process to save memory
    g.create_formats_()
    # squeeze(1) drops only the trailing axis added by nonzero(), so a mask
    # with a single selected node still yields a 1-D index tensor.
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze(1)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze(1)
    # thread limiting to avoid resource competition
    # Clamp to at least one thread: the integer division reaches 0 when there
    # are fewer than 2 * nprocs CPUs.
    os.environ["OMP_NUM_THREADS"] = str(
        max(1, mp.cpu_count() // 2 // nprocs)
    )

    data = (
        num_rels,
        data.num_classes,
        labels,
        train_idx,
        test_idx,
        target_idx,
        inv_target,
    )
    mp.spawn(run, args=(nprocs, devices, g, data), nprocs=nprocs)