import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import dgl
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from dgl.dataloading import MultiLayerNeighborSampler, DataLoader
from dgl.nn.pytorch import RelGraphConv

class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, h_dim)
        # two-layer RGCN
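        # num_bases=num_rels means the basis decomposition applies effectively no
        # parameter sharing across relations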
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)
        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)
        
    def forward(self, g):
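        # g is a pair of blocks (message flow graphs), one per RGCN layer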
        x = self.emb(g[0].srcdata[dgl.NID])
        h = F.relu(self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata['norm']))
        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata['norm'])
        return h
    
def evaluate(model, labels, dataloader, inv_target):
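    # run inference over the local shard of seeds; returns [num_correct, num_seeds]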
    model.eval()
    eval_logits = []
    eval_seeds = []
    with torch.no_grad():
        for input_nodes, output_nodes, blocks in dataloader:
            output_nodes = inv_target[output_nodes]
            for block in blocks:
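                # per-edge normalizer 1/in_degree(dst), consumed by RelGraphConv as 'norm'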
                block.edata['norm'] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            eval_logits.append(logits.cpu().detach())
            eval_seeds.append(output_nodes.cpu().detach())
    eval_logits = torch.cat(eval_logits)
    eval_seeds = torch.cat(eval_seeds)
    num_seeds = len(eval_seeds)
    # torchmetrics >= 0.11 requires the classification task to be given explicitly
    loc_sum = accuracy(eval_logits.argmax(dim=1), labels[eval_seeds].cpu(),
                       task='multiclass', num_classes=eval_logits.shape[1]) * float(num_seeds)
    return torch.tensor([loc_sum.item(), float(num_seeds)])

def train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model):
    # define loss function and optimizer
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    # construct sampler and dataloader
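    # sample at most 4 neighbors per node for each of the two layers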
    sampler = MultiLayerNeighborSampler([4, 4])
    train_loader = DataLoader(g, target_idx[train_idx], sampler, device=device, 
                              batch_size=100, shuffle=True, use_ddp=True)
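    # use_ddp=True partitions the seed nodes across ranks so each process trains on its own shard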
    # no separate validation subset; reuse the training indices for validation
    val_loader = DataLoader(g, target_idx[train_idx], sampler, device=device, 
                            batch_size=100, shuffle=False, use_ddp=True)
    for epoch in range(50):
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(train_loader):
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                block.edata['norm'] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            loss = loss_fcn(logits, labels[output_nodes])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # torchmetrics accuracy is num_correct / num_seeds, so evaluate() returns
        # [loc_accuracy * loc_num_seeds, loc_num_seeds] for reduction across ranks
        loc_acc_split = evaluate(model, labels, val_loader, inv_target).to(device)
        dist.reduce(loc_acc_split, 0)
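        # default reduce op is SUM, so rank 0 now holds the global [num_correct, num_seeds]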
        if proc_id == 0:
            acc = loc_acc_split[0] / loc_acc_split[1]
            print("Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f}"
                  .format(epoch, total_loss / (it + 1), acc.item()))
            
def run(proc_id, nprocs, devices, g, data):
    # find corresponding device for my rank
    device = devices[proc_id]
    torch.cuda.set_device(device)
    # initialize process group and unpack data for sub-processes
    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345',
                            world_size=nprocs, rank=proc_id)
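    # assumes port 12345 on localhost is free; adjust init_method if it is not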
    num_rels, num_classes, labels, train_idx, test_idx, target_idx, inv_target = data
    labels = labels.to(device)
    inv_target = inv_target.to(device)
    # create RGCN model (distributed)
    in_size = g.num_nodes()
    out_size = num_classes
    model = RGCN(in_size, 16, out_size, num_rels).to(device)
    model = DistributedDataParallel(model, device_ids=[device], output_device=device)
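    # DDP synchronizes gradients across ranks during backward()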
    # training + testing
    train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model)
    test_sampler = MultiLayerNeighborSampler([-1, -1]) # -1 for sampling all neighbors
    test_loader = DataLoader(g, target_idx[test_idx], test_sampler, device=device, 
                             batch_size=32, shuffle=False, use_ddp=True)
    loc_acc_split = evaluate(model, labels, test_loader, inv_target).to(device)
    dist.reduce(loc_acc_split, 0)
    if proc_id == 0:
        acc = loc_acc_split[0] / loc_acc_split[1]
        print("Test accuracy {:.4f}".format(acc.item()))
    # cleanup process group
    dist.destroy_process_group()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RGCN for entity classification with sampling (multi-gpu)')
    parser.add_argument("--dataset", type=str, default="aifb",
                        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').")
    parser.add_argument("--gpu", type=str, default='0',
114
115
                           help="GPU(s) in use. Can be a list of gpu ids for multi-gpu training,"
                                " e.g., 0,1,2,3.")
    args = parser.parse_args()
    devices = list(map(int, args.gpu.split(',')))
    nprocs = len(devices)
    print(f'Training with DGL built-in RGCN module with sampling using {nprocs} GPU(s)')

    # load and preprocess the dataset in the master (parent) process
    if args.dataset == 'aifb':
        data = AIFBDataset()
    elif args.dataset == 'mutag':
        data = MUTAGDataset()
    elif args.dataset == 'bgs':
        data = BGSDataset()
    elif args.dataset == 'am':
        data = AMDataset()
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    g = data[0]
    num_rels = len(g.canonical_etypes)
    category = data.predict_category
    labels = g.nodes[category].data.pop('labels')
    train_mask = g.nodes[category].data.pop('train_mask')
    test_mask = g.nodes[category].data.pop('test_mask')
    # find target category and node id
    category_id = g.ntypes.index(category)
    g = dgl.to_homogeneous(g)
    node_ids = torch.arange(g.num_nodes())
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    # rename the fields since they can be overwritten by the DataLoader
    g.ndata['ntype'] = g.ndata.pop(dgl.NTYPE)
    g.ndata['type_id'] = g.ndata.pop(dgl.NID)
    # find the mapping (inv_target) from global node IDs to type-specific node IDs
    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64)
    inv_target[target_idx] = torch.arange(0, target_idx.shape[0], dtype=inv_target.dtype)
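    # e.g., if the target nodes have global IDs [5, 9], then inv_target[5] = 0 and inv_target[9] = 1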
    # create graph formats and train/test indices before spawning to avoid
    # rebuilding them in each sub-process, which saves memory
    g.create_formats_()
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    # limit the number of OMP threads per process to avoid resource contention
    os.environ['OMP_NUM_THREADS'] = str(mp.cpu_count() // 2 // nprocs)

    data = num_rels, data.num_classes, labels, train_idx, test_idx, target_idx, inv_target
    mp.spawn(run, args=(nprocs, devices, g, data), nprocs=nprocs)
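
    # Example usage (assuming two visible GPUs):
    #   python entity_sample_multi_gpu.py --dataset aifb --gpu 0,1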
