link.py 11.1 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
import dgl
2
3
import numpy as np
import torch
Mufei Li's avatar
Mufei Li committed
4
5
import torch.nn as nn
import torch.nn.functional as F
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
6
import tqdm
Mufei Li's avatar
Mufei Li committed
7
8
from dgl.data.knowledge_graph import FB15k237Dataset
from dgl.dataloading import GraphDataLoader
9
from dgl.nn.pytorch import RelGraphConv
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
10

11
12
13
14
15
16

# for building training/testing graphs
def get_subset_g(g, mask, num_rels, bidirected=False):
    """Build a new graph from the edges of ``g`` selected by ``mask``.

    When ``bidirected`` is True every kept edge is mirrored, and the
    reverse edges receive relation IDs shifted by ``num_rels``.
    """
    all_src, all_dst = g.edges()
    keep_src = all_src[mask]
    keep_dst = all_dst[mask]
    keep_rel = g.edata["etype"][mask]

    if bidirected:
        keep_src, keep_dst = (
            torch.cat([keep_src, keep_dst]),
            torch.cat([keep_dst, keep_src]),
        )
        # Inverse relations occupy the ID range [num_rels, 2 * num_rels).
        keep_rel = torch.cat([keep_rel, keep_rel + num_rels])

    subset = dgl.graph((keep_src, keep_dst), num_nodes=g.num_nodes())
    subset.edata[dgl.ETYPE] = keep_rel
    return subset

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
29

30
31
32
33
34
35
36
37
class GlobalUniform:
    """Uniform edge-ID sampler over all edges of a graph."""

    def __init__(self, g, sample_size):
        # How many edge IDs each call to sample() returns.
        self.sample_size = sample_size
        # Candidate pool: every edge ID of the graph.
        self.eids = np.arange(g.num_edges())

    def sample(self):
        """Draw ``sample_size`` edge IDs uniformly (with replacement)."""
        chosen = np.random.choice(self.eids, self.sample_size)
        return torch.from_numpy(chosen)

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
38

39
class NegativeSampler:
    """Generate negative triplets by corrupting the head or tail of positives."""

    def __init__(self, k=10):  # negative sampling rate = 10
        self.k = k

    def sample(self, pos_samples, num_nodes):
        """Return ``(samples, labels)``: positives followed by ``k`` corrupted
        copies of each, with binary labels (1 = positive, 0 = negative).
        """
        n_pos = len(pos_samples)
        n_neg = n_pos * self.k
        neg_samples = np.tile(pos_samples, (self.k, 1))

        replacements = np.random.randint(num_nodes, size=n_neg)
        coin = np.random.uniform(size=n_neg)
        # Each negative corrupts exactly one side: head if coin > 0.5,
        # tail otherwise.
        corrupt_head = coin > 0.5
        corrupt_tail = ~corrupt_head
        neg_samples[corrupt_head, 0] = replacements[corrupt_head]
        neg_samples[corrupt_tail, 2] = replacements[corrupt_tail]
        all_samples = np.concatenate((pos_samples, neg_samples))

        # binary labels indicating positive and negative samples
        labels = np.zeros(n_pos * (self.k + 1), dtype=np.float32)
        labels[:n_pos] = 1

        return torch.from_numpy(all_samples), torch.from_numpy(labels)

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
62

63
64
65
66
67
68
69
70
class SubgraphIterator:
    """Map-style dataset yielding one sampled training subgraph per 'epoch'.

    Each item is ``(sub_g, uniq_v, samples, labels)``: a bidirected subgraph,
    the original IDs of its nodes, and positive/negative triplets with labels.
    """

    def __init__(self, g, num_rels, sample_size=30000, num_epochs=6000):
        self.g = g
        self.num_rels = num_rels
        self.sample_size = sample_size
        self.num_epochs = num_epochs
        self.pos_sampler = GlobalUniform(g, sample_size)
        self.neg_sampler = NegativeSampler()

    def __len__(self):
        # One subgraph per training epoch.
        return self.num_epochs

    def __getitem__(self, i):
        eids = self.pos_sampler.sample()
        head, tail = self.g.find_edges(eids)
        head, tail = head.numpy(), tail.numpy()
        rel = self.g.edata[dgl.ETYPE][eids].numpy()

        # relabel nodes to have consecutive node IDs
        uniq_v, inverse = np.unique((head, tail), return_inverse=True)
        num_nodes = len(uniq_v)
        # ``inverse`` is the concatenation of head, tail with relabeled IDs.
        head, tail = np.reshape(inverse, (2, -1))
        relabeled = np.stack((head, rel, tail)).transpose()

        samples, labels = self.neg_sampler.sample(relabeled, num_nodes)

        # use only half of the positive edges when building the subgraph
        keep = np.random.choice(
            np.arange(self.sample_size),
            size=int(self.sample_size / 2),
            replace=False,
        )
        head, tail, rel = head[keep], tail[keep], rel[keep]
        # Mirror edges; inverse relations get IDs shifted by num_rels.
        head, tail = np.concatenate((head, tail)), np.concatenate((tail, head))
        rel = np.concatenate((rel, rel + self.num_rels))
        sub_g = dgl.graph((head, tail), num_nodes=num_nodes)
        sub_g.edata[dgl.ETYPE] = torch.from_numpy(rel)
        sub_g.edata["norm"] = dgl.norm_by_dst(sub_g).unsqueeze(-1)
        uniq_v = torch.from_numpy(uniq_v).view(-1).long()

        return sub_g, uniq_v, samples, labels

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
108

109
110
111
112
113
class RGCN(nn.Module):
    """Two-layer relational GCN encoder with a learned node-embedding table."""

    def __init__(self, num_nodes, h_dim, num_rels):
        super().__init__()
        # two-layer RGCN
        self.emb = nn.Embedding(num_nodes, h_dim)
        # Both layers use block-diagonal-decomposition regularization
        # with 100 bases and a self-loop.
        conv_kwargs = dict(regularizer="bdd", num_bases=100, self_loop=True)
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, **conv_kwargs)
        self.conv2 = RelGraphConv(h_dim, h_dim, num_rels, **conv_kwargs)
        self.dropout = nn.Dropout(0.2)

    def forward(self, g, nids):
        # Look up embeddings for the requested node IDs, then apply
        # conv -> ReLU -> dropout, conv -> dropout.
        x = self.emb(nids)
        hidden = self.conv1(g, x, g.edata[dgl.ETYPE], g.edata["norm"])
        hidden = self.dropout(F.relu(hidden))
        hidden = self.conv2(g, hidden, g.edata[dgl.ETYPE], g.edata["norm"])
        return self.dropout(hidden)
Mufei Li's avatar
Mufei Li committed
138

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
139

Mufei Li's avatar
Mufei Li committed
140
class LinkPredict(nn.Module):
    """DistMult link predictor on top of an RGCN encoder."""

    def __init__(self, num_nodes, num_rels, h_dim=500, reg_param=0.01):
        super().__init__()
        # The encoder sees 2 * num_rels relation types: the originals
        # plus their inverses.
        self.rgcn = RGCN(num_nodes, h_dim, num_rels * 2)
        self.reg_param = reg_param
        self.w_relation = nn.Parameter(torch.Tensor(num_rels, h_dim))
        nn.init.xavier_uniform_(
            self.w_relation, gain=nn.init.calculate_gain("relu")
        )

    def calc_score(self, embedding, triplets):
        # DistMult score: sum_d s_d * r_d * o_d for each triplet row.
        heads = embedding[triplets[:, 0]]
        rels = self.w_relation[triplets[:, 1]]
        tails = embedding[triplets[:, 2]]
        return torch.sum(heads * rels * tails, dim=1)

    def forward(self, g, nids):
        return self.rgcn(g, nids)

    def regularization_loss(self, embedding):
        # Mean-squared (L2-style) penalty on node and relation embeddings.
        return torch.mean(embedding.pow(2)) + torch.mean(
            self.w_relation.pow(2)
        )

    def get_loss(self, embed, triplets, labels):
        # each row in the triplets is a 3-tuple of (source, relation, destination)
        score = self.calc_score(embed, triplets)
        predict_loss = F.binary_cross_entropy_with_logits(score, labels)
        return predict_loss + self.reg_param * self.regularization_loss(embed)

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
170
171
172
173

def filter(
    triplets_to_filter, target_s, target_r, target_o, num_nodes, filter_o=True
):
    """Get candidate heads or tails to score.

    NOTE(review): this shadows the builtin ``filter``; kept for call-site
    compatibility.
    """
    target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
    # The ground-truth node always sits first so its rank can later be
    # found at index 0 of the score list.
    candidate_nodes = [target_o if filter_o else target_s]
    for node in range(num_nodes):
        if filter_o:
            candidate = (target_s, target_r, node)
        else:
            candidate = (node, target_r, target_o)
        # Do not consider a node if it leads to a real triplet.
        if candidate not in triplets_to_filter:
            candidate_nodes.append(node)
    return torch.LongTensor(candidate_nodes)
Mufei Li's avatar
Mufei Li committed
189

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
190
191
192
193

def perturb_and_get_filtered_rank(
    emb, w, s, r, o, test_size, triplets_to_filter, filter_o=True
):
    """Perturb subject or object in the triplets and return filtered ranks."""
    num_nodes = emb.shape[0]
    ranks = []
    for idx in tqdm.tqdm(range(test_size), desc="Evaluate"):
        target_s, target_r, target_o = s[idx], r[idx], o[idx]
        # Candidate nodes with known true triplets filtered out; the
        # ground-truth node occupies position 0.
        candidate_nodes = filter(
            triplets_to_filter,
            target_s,
            target_r,
            target_o,
            num_nodes,
            filter_o=filter_o,
        )
        if filter_o:
            emb_s, emb_o = emb[target_s], emb[candidate_nodes]
        else:
            emb_s, emb_o = emb[candidate_nodes], emb[target_o]
        # DistMult score for each candidate, squashed to (0, 1).
        scores = torch.sigmoid(
            torch.sum(emb_s * w[target_r] * emb_o, dim=1)
        )
        _, order = torch.sort(scores, descending=True)
        # Rank of the ground truth = position of index 0 in sorted order.
        ranks.append(int((order == 0).nonzero()))
    return torch.LongTensor(ranks)
Mufei Li's avatar
Mufei Li committed
224

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
225

226
def calc_mrr(emb, w, mask, triplets_to_filter, batch_size=100, filter=True):
    """Compute the filtered MRR over the triplets selected by ``mask``.

    NOTE(review): ``batch_size`` and ``filter`` are unused in the body and
    are kept only so existing call sites remain valid (``filter`` also
    shadows the builtin).
    """
    with torch.no_grad():
        test_triplets = triplets_to_filter[mask]
        s, r, o = (
            test_triplets[:, 0],
            test_triplets[:, 1],
            test_triplets[:, 2],
        )
        test_size = len(s)
        # All known triplets as a set of tuples, for O(1) filtering.
        triplets_to_filter = {
            tuple(triplet) for triplet in triplets_to_filter.tolist()
        }
        # Rank by perturbing the subject, then by perturbing the object.
        ranks_s = perturb_and_get_filtered_rank(
            emb, w, s, r, o, test_size, triplets_to_filter, filter_o=False
        )
        ranks_o = perturb_and_get_filtered_rank(
            emb, w, s, r, o, test_size, triplets_to_filter
        )
        ranks = torch.cat([ranks_s, ranks_o]) + 1  # change to 1-indexed
        mrr = torch.mean(1.0 / ranks.float()).item()
    return mrr
Mufei Li's avatar
Mufei Li committed
244

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
245
246
247
248
249

def train(
    dataloader,
    test_g,
    test_nids,
    val_mask,
    triplets,
    device,
    model_state_file,
    model,
):
    """Train ``model`` on sampled subgraphs, validating every 500 epochs.

    Each dataloader item is a single sampled training subgraph with its
    node IDs, triplet samples, and labels. Validation computes filtered
    MRR on CPU over ``test_g``; the best checkpoint is written to
    ``model_state_file``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    best_mrr = 0
    for epoch, batch_data in enumerate(dataloader):  # single graph batch
        model.train()
        g, train_nids, edges, labels = batch_data
        g = g.to(device)
        train_nids = train_nids.to(device)
        edges = edges.to(device)
        labels = labels.to(device)

        embed = model(g, train_nids)
        loss = model.get_loss(embed, edges, labels)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=1.0
        )  # clip gradients
        optimizer.step()

        print(
            "Epoch {:04d} | Loss {:.4f} | Best MRR {:.4f}".format(
                epoch, loss.item(), best_mrr
            )
        )

        if (epoch + 1) % 500 == 0:
            # perform validation on CPU because full graph is too large
            model = model.cpu()
            model.eval()
            # FIX: run the full-graph forward pass under no_grad so a
            # throwaway autograd graph is not built over the whole test
            # graph (major memory saving; calc_mrr only reads embed).
            with torch.no_grad():
                embed = model(test_g, test_nids)
            mrr = calc_mrr(
                embed, model.w_relation, val_mask, triplets, batch_size=500
            )
            # save best model
            if best_mrr < mrr:
                best_mrr = mrr
                torch.save(
                    {"state_dict": model.state_dict(), "epoch": epoch},
                    model_state_file,
                )
            model = model.to(device)

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
296
297
298
299

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIX: plain string — the previous f-string had no placeholders.
    print("Training with DGL built-in RGCN module")

    # load and preprocess dataset
    data = FB15k237Dataset(reverse=False)
    g = data[0]
    num_nodes = g.num_nodes()
    num_rels = data.num_rels
    train_g = get_subset_g(g, g.edata["train_mask"], num_rels)
    # NOTE: the evaluation graph is the bidirected *training* graph;
    # val/test triplets are selected separately via masks over the full
    # triplet list, so no test edges leak into message passing.
    test_g = get_subset_g(g, g.edata["train_mask"], num_rels, bidirected=True)
    test_g.edata["norm"] = dgl.norm_by_dst(test_g).unsqueeze(-1)
    test_nids = torch.arange(0, num_nodes)
    val_mask = g.edata["val_mask"]
    test_mask = g.edata["test_mask"]
    subg_iter = SubgraphIterator(train_g, num_rels)  # uniform edge sampling
    dataloader = GraphDataLoader(
        subg_iter, batch_size=1, collate_fn=lambda x: x[0]
    )

    # Prepare data for metric computation
    src, dst = g.edges()
    triplets = torch.stack([src, g.edata["etype"], dst], dim=1)

    # create RGCN model
    model = LinkPredict(num_nodes, num_rels).to(device)

    # train
    model_state_file = "model_state.pth"
    train(
        dataloader,
        test_g,
        test_nids,
        val_mask,
        triplets,
        device,
        model_state_file,
        model,
    )

    # testing
    print("Testing...")
    checkpoint = torch.load(model_state_file)
    model = model.cpu()  # test on CPU
    model.eval()
    model.load_state_dict(checkpoint["state_dict"])
    # FIX: no gradients are needed at test time; run the full-graph
    # forward under no_grad to avoid building an autograd graph.
    with torch.no_grad():
        embed = model(test_g, test_nids)
    best_mrr = calc_mrr(
        embed, model.w_relation, test_mask, triplets, batch_size=500
    )
    print(
        "Best MRR {:.4f} achieved using the epoch {:04d}".format(
            best_mrr, checkpoint["epoch"]
        )
    )