"vscode:/vscode.git/clone" did not exist on "ee1ffe2e88595407209ff2a27494d6c107fac234"
entity_sample.py 5.39 KB
Newer Older
1
2
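"""RGCN entity classification with neighbor sampling.

Trains a two-layer relational graph convolutional network (RGCN) on one of
DGL's RDF entity-classification datasets (AIFB, MUTAG, BGS, AM), using
MultiLayerNeighborSampler and DataLoader for mini-batch training.

Example (assuming DGL, PyTorch and torchmetrics are installed):
    python entity_sample.py --dataset aifb
"""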
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy

import dgl
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from dgl.dataloading import MultiLayerNeighborSampler, DataLoader
from dgl.nn.pytorch import RelGraphConv


class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, h_dim)
        # two-layer RGCN
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)
        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)

    def forward(self, g):
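        # g is a list of two blocks (message flow graphs) produced by the sampler;
        # the first block's input node IDs index the learnable node embeddings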
        x = self.emb(g[0].srcdata[dgl.NID])
        h = F.relu(self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata['norm']))
        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata['norm'])
        return h

def evaluate(model, labels, dataloader, inv_target):
    model.eval()
    eval_logits = []
    eval_seeds = []
    with torch.no_grad():
        for input_nodes, output_nodes, blocks in dataloader:
            # map global node IDs of the sampled seeds back to type-specific IDs
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                block.edata['norm'] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            eval_logits.append(logits.cpu().detach())
            eval_seeds.append(output_nodes.cpu().detach())
    eval_logits = torch.cat(eval_logits)
    eval_seeds = torch.cat(eval_seeds)
    return accuracy(eval_logits.argmax(dim=1), labels[eval_seeds].cpu()).item()

def train(device, g, target_idx, labels, train_mask, model, inv_target):
    # define train idx, loss function and optimizer
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    # construct sampler and dataloaders
    sampler = MultiLayerNeighborSampler([4, 4])
    train_loader = DataLoader(g, target_idx[train_idx], sampler, device=device,
                              batch_size=100, shuffle=True)
    # no separate validation subset; reuse the train index for validation
    val_loader = DataLoader(g, target_idx[train_idx], sampler, device=device,
                            batch_size=100, shuffle=False)
    for epoch in range(50):
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(train_loader):
            # map global node IDs of the sampled seeds back to type-specific IDs
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                block.edata['norm'] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            loss = loss_fcn(logits, labels[output_nodes])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        acc = evaluate(model, labels, val_loader, inv_target)
        print("Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f}"
              .format(epoch, total_loss / (it + 1), acc))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RGCN for entity classification with sampling')
    parser.add_argument("--dataset", type=str, default="aifb",
                        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').")
    args = parser.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Training with DGL built-in RGCN module with sampling.')

    # load and preprocess dataset
    if args.dataset == 'aifb':
        data = AIFBDataset()
    elif args.dataset == 'mutag':
        data = MUTAGDataset()
    elif args.dataset == 'bgs':
        data = BGSDataset()
    elif args.dataset == 'am':
        data = AMDataset()
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    g = data[0]
    num_rels = len(g.canonical_etypes)
    category = data.predict_category
    labels = g.nodes[category].data.pop('labels').to(device)
    train_mask = g.nodes[category].data.pop('train_mask')
    test_mask = g.nodes[category].data.pop('test_mask')
    # find target category and node id
    category_id = g.ntypes.index(category)
    g = dgl.to_homogeneous(g)
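    # dgl.to_homogeneous stores the original node/edge types in g.ndata[dgl.NTYPE]
    # and g.edata[dgl.ETYPE], and the per-type IDs in dgl.NID / dgl.EID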
    node_ids = torch.arange(g.num_nodes())
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    # rename these fields, as they can be overwritten by the DataLoader
    g.ndata['ntype'] = g.ndata.pop(dgl.NTYPE)
    g.ndata['type_id'] = g.ndata.pop(dgl.NID)
    # find the mapping (inv_target) from global node IDs to type-specific node IDs
    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64).to(device)
    inv_target[target_idx] = torch.arange(0, target_idx.shape[0], dtype=inv_target.dtype).to(device)

    # create RGCN model
    in_size = g.num_nodes() # featureless with one-hot encoding
    out_size = data.num_classes
    model = RGCN(in_size, 16, out_size, num_rels).to(device)

    train(device, g, target_idx, labels, train_mask, model, inv_target)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    test_sampler = MultiLayerNeighborSampler([-1, -1]) # -1 for sampling all neighbors
    test_loader = DataLoader(g, target_idx[test_idx], test_sampler, device=device, 
                             batch_size=32, shuffle=False)
    acc = evaluate(model, labels, test_loader, inv_target)
    print("Test accuracy {:.4f}".format(acc))