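"""Multi-GPU RGCN for entity classification with neighbor sampling.

Spawns one process per GPU; each trains a two-layer RGCN on its shard of an
RDF dataset (AIFB, MUTAG, BGS, or AM) under DistributedDataParallel.
"""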
import argparse
import os

import dgl
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from dgl.dataloading import DataLoader, MultiLayerNeighborSampler
from dgl.nn.pytorch import RelGraphConv
from torch.nn.parallel import DistributedDataParallel
from torchmetrics.functional import accuracy


class RGCN(nn.Module):
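    """Two-layer RGCN with a learnable embedding table as input features.

    Relation weights use basis decomposition; with num_bases=num_rels every
    relation effectively gets its own basis.
    """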
    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, h_dim)
        # two-layer RGCN
        self.conv1 = RelGraphConv(
            h_dim,
            h_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )
        self.conv2 = RelGraphConv(
            h_dim,
            out_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )

    def forward(self, blocks):
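        # `blocks` is the pair of message flow graphs (MFGs) emitted by the
        # sampler; srcdata[dgl.NID] gives the input nodes' original IDs,
        # which index the embedding table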
        x = self.emb(blocks[0].srcdata[dgl.NID])
        h = F.relu(
            self.conv1(
                blocks[0], x, blocks[0].edata[dgl.ETYPE], blocks[0].edata["norm"]
            )
        )
        h = self.conv2(
            blocks[1], h, blocks[1].edata[dgl.ETYPE], blocks[1].edata["norm"]
        )
        return h


def evaluate(model, labels, dataloader, inv_target):
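    """Evaluate the model over this rank's dataloader shard.

    Returns a tensor [local_correct_count, local_seed_count] so the caller
    can sum across ranks before computing the global accuracy.
    """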
    model.eval()
    eval_logits = []
    eval_seeds = []
    with torch.no_grad():
        for input_nodes, output_nodes, blocks in dataloader:
            output_nodes = inv_target[output_nodes]
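            # attach the per-edge normalizer (1 / destination in-degree)
            # that RGCN.forward passes to each RelGraphConv layer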
            for block in blocks:
                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            eval_logits.append(logits.cpu().detach())
            eval_seeds.append(output_nodes.cpu().detach())
    eval_logits = torch.cat(eval_logits)
    eval_seeds = torch.cat(eval_seeds)
    num_seeds = len(eval_seeds)
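    # note: this is the pre-0.11 torchmetrics functional API; newer releases
    # require accuracy(preds, target, task="multiclass", num_classes=...)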
    loc_sum = accuracy(
        eval_logits.argmax(dim=1), labels[eval_seeds].cpu()
    ) * float(num_seeds)
    return torch.tensor([loc_sum.item(), float(num_seeds)])


def train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model):
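    """Train on this rank's seed-node shard, validating after each epoch."""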
    # define loss function and optimizer
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    # construct sampler and dataloader
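    # sample 4 neighbors per node at each of the two layers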
    sampler = MultiLayerNeighborSampler([4, 4])
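    # use_ddp=True partitions the seed nodes across ranks so each process
    # trains on a distinct shard every epoch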
    train_loader = DataLoader(
        g,
        target_idx[train_idx],
        sampler,
        device=device,
        batch_size=100,
        shuffle=True,
        use_ddp=True,
    )
    # these datasets have no separate validation split, so reuse the
    # training indices for validation
    val_loader = DataLoader(
        g,
        target_idx[train_idx],
        sampler,
        device=device,
        batch_size=100,
        shuffle=False,
        use_ddp=True,
    )
    for epoch in range(50):
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(train_loader):
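            # map global (homogeneous-graph) node IDs to type-specific IDs
            # so they index `labels` correctly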
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            loss = loss_fcn(logits, labels[output_nodes])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # torchmetrics accuracy is num_correct / num_samples, so evaluate()
        # returns [local_accuracy * local_num_seeds, local_num_seeds]
        loc_acc_split = evaluate(model, labels, val_loader, inv_target).to(
            device
        )
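        # sum the per-rank [weighted accuracy, seed count] pairs onto rank 0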
        dist.reduce(loc_acc_split, 0)
        if proc_id == 0:
            acc = loc_acc_split[0] / loc_acc_split[1]
            print(
                "Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} ".format(
                    epoch, total_loss / (it + 1), acc.item()
                )
            )


def run(proc_id, nprocs, devices, g, data):
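    """Per-process entry point: set up DDP, train, then test."""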
    # find the device corresponding to this process's rank
    device = devices[proc_id]
    torch.cuda.set_device(device)
    # initialize process group and unpack data for sub-processes
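    # all workers rendezvous at a fixed local TCP address; NCCL provides the
    # GPU collectives behind DDP and dist.reduce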
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12345",
        world_size=nprocs,
        rank=proc_id,
    )
    (
        num_rels,
        num_classes,
        labels,
        train_idx,
        test_idx,
        target_idx,
        inv_target,
    ) = data
    labels = labels.to(device)
    inv_target = inv_target.to(device)
    # create RGCN model (distributed)
    in_size = g.num_nodes()
    out_size = num_classes
    model = RGCN(in_size, 16, out_size, num_rels).to(device)
    model = DistributedDataParallel(
        model, device_ids=[device], output_device=device
    )
    # training + testing
    train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model)
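    # test with full neighborhoods so evaluation is exact rather than sampled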
    test_sampler = MultiLayerNeighborSampler(
        [-1, -1]
    )  # -1 for sampling all neighbors
    test_loader = DataLoader(
        g,
        target_idx[test_idx],
        test_sampler,
        device=device,
        batch_size=32,
        shuffle=False,
        use_ddp=True,
    )
    loc_acc_split = evaluate(model, labels, test_loader, inv_target).to(device)
    dist.reduce(loc_acc_split, 0)
    if proc_id == 0:
        acc = loc_acc_split[0] / loc_acc_split[1]
        print("Test accuracy {:.4f}".format(acc))
    # clean up the process group
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="RGCN for entity classification with sampling (multi-gpu)"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="aifb",
        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="0",
        help="GPU(s) in use. Can be a list of gpu ids for multi-gpu training,"
        " e.g., 0,1,2,3.",
    )
    args = parser.parse_args()
    devices = list(map(int, args.gpu.split(",")))
    nprocs = len(devices)
    print(
        "Training with DGL built-in RGCN module with sampling using "
        f"{nprocs} GPU(s)"
    )

    # load and preprocess the dataset in the master (parent) process
    if args.dataset == "aifb":
        data = AIFBDataset()
    elif args.dataset == "mutag":
        data = MUTAGDataset()
    elif args.dataset == "bgs":
        data = BGSDataset()
    elif args.dataset == "am":
        data = AMDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
    num_rels = len(g.canonical_etypes)
    category = data.predict_category
    labels = g.nodes[category].data.pop("labels")
    train_mask = g.nodes[category].data.pop("train_mask")
    test_mask = g.nodes[category].data.pop("test_mask")
    # find target category and node id
    category_id = g.ntypes.index(category)
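    # flatten the heterograph into a homogeneous graph; node/edge types are
    # kept as integer data (dgl.NTYPE / dgl.ETYPE) for RelGraphConv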
    g = dgl.to_homogeneous(g)
    node_ids = torch.arange(g.num_nodes())
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    # rename the fields, since the DataLoader may overwrite them
    g.ndata["ntype"] = g.ndata.pop(dgl.NTYPE)
    g.ndata["type_id"] = g.ndata.pop(dgl.NID)
    # find the mapping (inv_target) from global node IDs to type-specific node IDs
    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64)
    inv_target[target_idx] = torch.arange(
        0, target_idx.shape[0], dtype=inv_target.dtype
    )
    # materialize graph formats and the train/test indices once in the parent
    # process so each sub-process does not recreate them, saving memory
    g.create_formats_()
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    # limit OMP threads per process to avoid CPU oversubscription
    os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // nprocs)

    data = (
        num_rels,
        data.num_classes,
        labels,
        train_idx,
        test_idx,
        target_idx,
        inv_target,
    )
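    # spawn one worker process per GPU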
    mp.spawn(run, args=(nprocs, devices, g, data), nprocs=nprocs)