link.py 5.12 KB
Newer Older
Mufei Li's avatar
Mufei Li committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
"""
Differences compared to MichSchli/RelationPrediction
* Raw MRR can be reported in addition to filtered MRR, selected with
  ``--eval-protocol`` (default: filtered).
* By default, we use uniform edge sampling instead of the neighbor-based edge
  sampling used in the authors' code. In practice, we find it achieves similar MRR.
"""

import argparse
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from dgl.data.knowledge_graph import FB15k237Dataset
from dgl.dataloading import GraphDataLoader

from link_utils import preprocess, SubgraphIterator, calc_mrr
from model import RGCN

class LinkPredict(nn.Module):
    """R-GCN encoder with a DistMult decoder for link prediction.

    The encoder is an ``RGCN`` (from ``model.py``) built with the "bdd"
    regularizer over ``num_rels * 2`` relation types -- presumably the original
    relations plus the inverse relations added by ``preprocess``; TODO confirm.
    The decoder scores a (subject, relation, object) triplet as
    ``sum(e_s * w_r * e_o)`` with one learned vector ``w_r`` per original relation.

    Args:
        in_dim: First argument forwarded to ``RGCN``; ``main`` passes the number
            of nodes here (presumably sizing a node embedding table -- confirm
            against ``model.py``).
        num_rels: Number of original relation types; sets the row count of
            ``w_relation``.
        h_dim: Hidden and output embedding size of the RGCN, and the width of
            each relation vector.
        num_bases: Number of bases forwarded to the "bdd" regularizer.
        dropout: Dropout probability forwarded to ``RGCN`` and also applied to
            the RGCN output in ``forward``.
        reg_param: Weight of the L2 regularization term in ``get_loss``.
    """

    def __init__(self, in_dim, num_rels, h_dim=500, num_bases=100, dropout=0.2, reg_param=0.01):
        super().__init__()
        self.rgcn = RGCN(in_dim, h_dim, h_dim, num_rels * 2, regularizer="bdd",
                         num_bases=num_bases, dropout=dropout, self_loop=True)
        self.dropout = nn.Dropout(dropout)
        self.reg_param = reg_param
        # th.empty allocates uninitialized storage (like the legacy th.Tensor
        # constructor); it is filled by the Xavier initialization right below.
        self.w_relation = nn.Parameter(th.empty(num_rels, h_dim))
        nn.init.xavier_uniform_(self.w_relation,
                                gain=nn.init.calculate_gain('relu'))

    def calc_score(self, embedding, triplets):
        """Return DistMult scores for a batch of triplets.

        Args:
            embedding: Node embedding tensor indexed by node ID along dim 0.
            triplets: Integer tensor whose rows are (source, relation, destination).

        Returns:
            1-D tensor with one logit per row of ``triplets`` (``get_loss``
            feeds these to ``binary_cross_entropy_with_logits``).
        """
        # DistMult: elementwise product of subject, relation and object vectors,
        # summed over the embedding dimension.
        s = embedding[triplets[:, 0]]
        r = self.w_relation[triplets[:, 1]]
        o = embedding[triplets[:, 2]]
        return th.sum(s * r * o, dim=1)

    def forward(self, g, nids):
        """Run the RGCN on graph ``g`` for nodes ``nids`` and apply output dropout."""
        return self.dropout(self.rgcn(g, nids=nids))

    def regularization_loss(self, embedding):
        """Return mean squared node embedding plus mean squared relation weight."""
        return th.mean(embedding.pow(2)) + th.mean(self.w_relation.pow(2))

    def get_loss(self, embed, triplets, labels):
        """Return binary cross-entropy on triplet scores plus weighted L2 penalty.

        Args:
            embed: Node embeddings (as produced by ``forward``).
            triplets: Rows of (source, relation, destination).
            labels: Float targets, one per triplet (1 for positive, 0 for negative).
        """
        score = self.calc_score(embed, triplets)
        predict_loss = F.binary_cross_entropy_with_logits(score, labels)
        reg_loss = self.regularization_loss(embed)
        return predict_loss + self.reg_param * reg_loss

def _evaluate(model, eval_g, eval_nids, eval_mask, triplets, eval_protocol):
    """Return the MRR of ``model`` on the triplets selected by ``eval_mask``.

    Evaluation runs on CPU because the full graph is too large for the GPU.
    The model is moved to CPU in place and left in eval mode; the caller is
    responsible for moving it back. Autograd is disabled so the full-graph
    forward pass does not keep a computation graph alive.
    """
    model.cpu()
    model.eval()
    with th.no_grad():
        embed = model(eval_g, eval_nids)
        return calc_mrr(embed, model.w_relation, eval_mask, triplets,
                        batch_size=500, eval_p=eval_protocol)


def main(args):
    """Train R-GCN + DistMult on FB15k-237 and report the test MRR.

    Every 500 training subgraphs the model is scored on the validation edges;
    the checkpoint with the best validation MRR is written to
    ``model_state.pth``. After training, that checkpoint (or the final weights,
    if no validation round ran) is scored once on the test edges.
    """
    data = FB15k237Dataset(reverse=False)
    graph = data[0]
    num_nodes = graph.num_nodes()
    num_rels = data.num_rels

    train_g, test_g = preprocess(graph, num_rels)
    test_nids = th.arange(0, num_nodes)
    # Model selection uses the validation edges so the test edges never
    # influence which checkpoint is reported.
    val_mask = graph.edata['val_mask']
    test_mask = graph.edata['test_mask']
    subg_iter = SubgraphIterator(train_g, num_rels, args.edge_sampler)
    dataloader = GraphDataLoader(subg_iter, batch_size=1, collate_fn=lambda x: x[0])

    # All (source, relation, destination) triplets of the graph; the masks
    # select which rows are scored, the rest are available for filtering.
    src, dst = graph.edges()
    triplets = th.stack([src, graph.edata['etype'], dst], dim=1)

    if args.gpu >= 0 and th.cuda.is_available():
        device = th.device(args.gpu)
    else:
        device = th.device('cpu')
    # Move the model before constructing the optimizer, as PyTorch recommends.
    model = LinkPredict(num_nodes, num_rels).to(device)
    optimizer = th.optim.Adam(model.parameters(), lr=1e-2)

    best_mrr = 0
    best_epoch = None  # epoch of the checkpoint written by *this* run, if any
    model_state_file = 'model_state.pth'
    for epoch, batch_data in enumerate(dataloader):
        model.train()

        g, train_nids, edges, labels = batch_data
        g = g.to(device)
        train_nids = train_nids.to(device)
        edges = edges.to(device)
        labels = labels.to(device)

        embed = model(g, train_nids)
        loss = model.get_loss(embed, edges, labels)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip gradients
        optimizer.step()

        print("Epoch {:04d} | Loss {:.4f} | Best MRR {:.4f}".format(epoch, loss.item(), best_mrr))

        if (epoch + 1) % 500 == 0:
            print("start eval")
            mrr = _evaluate(model, test_g, test_nids, val_mask, triplets,
                            args.eval_protocol)
            # save best model (selected by validation MRR)
            if best_mrr < mrr:
                best_mrr = mrr
                best_epoch = epoch
                th.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
            model.to(device)

    print("Start testing:")
    model.cpu()  # test on CPU
    if best_epoch is None:
        # No checkpoint was written by this run (e.g. fewer than 500 training
        # subgraphs); avoid loading a missing or stale file.
        print("No checkpoint saved; testing the final model")
    else:
        checkpoint = th.load(model_state_file, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        print("Using best epoch: {}".format(checkpoint['epoch']))
    test_mrr = _evaluate(model, test_g, test_nids, test_mask, triplets,
                         args.eval_protocol)
    print("Test MRR {:.4f}".format(test_mrr))

def build_arg_parser():
    """Return the command-line parser for this training script."""
    parser = argparse.ArgumentParser(description='RGCN for link prediction')
    parser.add_argument("--gpu", type=int, default=-1,
                        help="GPU device index; a negative value (default) runs on CPU. "
                             "Falls back to CPU when CUDA is unavailable.")
    parser.add_argument("--eval-protocol", type=str, default='filtered',
                        choices=['filtered', 'raw'],
                        help="Whether to use 'filtered' or 'raw' MRR for evaluation")
    parser.add_argument("--edge-sampler", type=str, default='uniform',
                        choices=['uniform', 'neighbor'],
                        # The two literals are concatenated, so the separator
                        # must live inside the first one.
                        help="Type of edge sampler: 'uniform' or 'neighbor'. "
                             "The original implementation uses neighbor sampler.")
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    print(args)
    main(args)