".github/vscode:/vscode.git/clone" did not exist on "26dc54ec0c69e0314ec0911cca6426e1c7a1133b"
link.py 11.1 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
import dgl
2
3
import numpy as np
import torch
Mufei Li's avatar
Mufei Li committed
4
5
import torch.nn as nn
import torch.nn.functional as F
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
6
import tqdm
Mufei Li's avatar
Mufei Li committed
7
8
from dgl.data.knowledge_graph import FB15k237Dataset
from dgl.dataloading import GraphDataLoader
9
from dgl.nn.pytorch import RelGraphConv
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
10

11
12
13
14
15
16

# for building training/testing graphs
def get_subset_g(g, mask, num_rels, bidirected=False):
    """Build a graph holding only the edges of ``g`` selected by ``mask``.

    Args:
        g: source graph; relation IDs are read from ``g.edata["etype"]``.
        mask: edge mask/index applied to ``g``'s edge list and relation IDs.
        num_rels: number of forward relation types; only used when
            ``bidirected`` is true, to offset the IDs of reversed edges.
        bidirected: when true, append a reversed copy of every kept edge
            whose relation ID is ``rel + num_rels``.

    Returns:
        A new graph with the same number of nodes as ``g`` and the kept
        relation IDs stored in ``edata[dgl.ETYPE]``.
    """
    all_src, all_dst = g.edges()
    kept_src = all_src[mask]
    kept_dst = all_dst[mask]
    kept_rel = g.edata["etype"][mask]

    if bidirected:
        # Reverse edges get their own relation IDs, shifted by num_rels.
        forward_src, forward_dst = kept_src, kept_dst
        kept_src = torch.cat([forward_src, forward_dst])
        kept_dst = torch.cat([forward_dst, forward_src])
        kept_rel = torch.cat([kept_rel, kept_rel + num_rels])

    subgraph = dgl.graph((kept_src, kept_dst), num_nodes=g.num_nodes())
    subgraph.edata[dgl.ETYPE] = kept_rel
    return subgraph

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
29

30
31
32
33
34
35
36
37
class GlobalUniform:
    """Draw edge IDs uniformly at random, with replacement, from a whole graph.

    Args:
        g: graph whose ``num_edges()`` defines the ID range ``[0, num_edges)``.
        sample_size: number of edge IDs returned by each ``sample`` call.
    """

    def __init__(self, g, sample_size):
        self.sample_size = sample_size
        # Candidate pool: every edge ID of ``g``.
        self.eids = np.arange(g.num_edges())

    def sample(self):
        """Return a tensor of ``sample_size`` edge IDs (duplicates possible)."""
        drawn = np.random.choice(self.eids, self.sample_size, replace=True)
        return torch.from_numpy(drawn)

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
38

39
class NegativeSampler:
    """Create negative triplets by corrupting positive ones.

    Every positive ``(src, rel, dst)`` row is copied ``k`` times. In each copy
    a fair coin decides whether the source (column 0) or the destination
    (column 2) is replaced by a uniformly random node ID in ``[0, num_nodes)``.
    Corruptions are not checked against known triplets, so a "negative" may
    coincide with a true edge.
    """

    def __init__(self, k=10):  # negative sampling rate = 10
        self.k = k

    def sample(self, pos_samples, num_nodes):
        """Return ``(samples, labels)`` tensors for a batch of positives.

        Args:
            pos_samples: array of positive triplets, one ``(src, rel, dst)``
                row per positive.
            num_nodes: node IDs for corruption are drawn from ``[0, num_nodes)``.

        Returns:
            samples: the positives followed by ``len(pos_samples) * k``
                negatives.
            labels: float32 vector with 1 for each positive and 0 for each
                negative, aligned with ``samples``.
        """
        n_pos = len(pos_samples)
        n_neg = n_pos * self.k
        negatives = np.tile(pos_samples, (self.k, 1))

        # Draw the random nodes first, then the coin flips.
        random_nodes = np.random.randint(num_nodes, size=n_neg)
        coin = np.random.uniform(size=n_neg)
        replace_head = coin > 0.5
        replace_tail = ~replace_head
        negatives[replace_head, 0] = random_nodes[replace_head]
        negatives[replace_tail, 2] = random_nodes[replace_tail]

        all_samples = np.concatenate((pos_samples, negatives))
        labels = np.concatenate(
            (
                np.ones(n_pos, dtype=np.float32),
                np.zeros(n_neg, dtype=np.float32),
            )
        )
        return torch.from_numpy(all_samples), torch.from_numpy(labels)

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
62

63
64
65
66
67
68
69
70
class SubgraphIterator:
    """Map-style dataset that yields one freshly sampled training subgraph per item.

    Each item is built from ``sample_size`` edges drawn uniformly (with
    replacement) from ``g``. The index passed to ``__getitem__`` is ignored,
    and ``num_epochs`` only sets the dataset length.
    """

    def __init__(self, g, num_rels, sample_size=30000, num_epochs=6000):
        self.g = g
        self.num_rels = num_rels
        self.sample_size = sample_size
        self.num_epochs = num_epochs
        self.pos_sampler = GlobalUniform(g, sample_size)
        self.neg_sampler = NegativeSampler()

    def __len__(self):
        return self.num_epochs

    def __getitem__(self, i):
        """Sample a subgraph.

        Returns:
            sub_g: compact graph over the sampled nodes. It holds a random
                half of the sampled positive edges plus their reverses
                (relation ``rel + num_rels``), with ``edata[dgl.ETYPE]`` and a
                per-edge ``edata["norm"]`` from ``dgl.norm_by_dst``.
            uniq_v: original node ID of each compact node, as an int64 tensor.
            samples, labels: positives (all sampled edges) plus negatives,
                from ``NegativeSampler``, using compact node IDs.
        """
        # Draw positive edges from the full training graph.
        eids = self.pos_sampler.sample()
        src_t, dst_t = self.g.find_edges(eids)
        src = src_t.numpy()
        dst = dst_t.numpy()
        rel = self.g.edata[dgl.ETYPE][eids].numpy()

        # Relabel the touched nodes to consecutive IDs 0..num_nodes-1.
        # The reshape handles both flat and input-shaped inverse layouts.
        uniq_v, inverse = np.unique((src, dst), return_inverse=True)
        num_nodes = len(uniq_v)
        src, dst = np.reshape(inverse, (2, -1))
        triplets = np.stack((src, rel, dst)).transpose()
        samples, labels = self.neg_sampler.sample(triplets, num_nodes)

        # Only half of the positive edges are used for message passing.
        keep = np.random.choice(
            np.arange(self.sample_size),
            size=int(self.sample_size / 2),
            replace=False,
        )
        src, dst, rel = src[keep], dst[keep], rel[keep]
        mp_src = np.concatenate((src, dst))
        mp_dst = np.concatenate((dst, src))
        mp_rel = np.concatenate((rel, rel + self.num_rels))

        sub_g = dgl.graph((mp_src, mp_dst), num_nodes=num_nodes)
        sub_g.edata[dgl.ETYPE] = torch.from_numpy(mp_rel)
        sub_g.edata["norm"] = dgl.norm_by_dst(sub_g).unsqueeze(-1)

        node_ids = torch.from_numpy(uniq_v).view(-1).long()
        return sub_g, node_ids, samples, labels

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
108

109
110
111
112
113
class RGCN(nn.Module):
    """Two-layer relational GCN encoder over learned node embeddings.

    Input features come from an ``nn.Embedding`` indexed by node ID. Both
    layers are ``RelGraphConv`` modules with the "bdd" regularizer,
    ``num_bases=100`` and a self-loop. A ReLU follows the first layer, and
    dropout (p=0.2) is applied after each layer.
    """

    def __init__(self, num_nodes, h_dim, num_rels):
        super().__init__()
        # One learnable input vector per node of the full graph.
        self.emb = nn.Embedding(num_nodes, h_dim)
        conv_options = dict(regularizer="bdd", num_bases=100, self_loop=True)
        # Submodules are created in this order so parameter initialization
        # consumes the RNG in a fixed order.
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, **conv_options)
        self.conv2 = RelGraphConv(h_dim, h_dim, num_rels, **conv_options)
        self.dropout = nn.Dropout(0.2)

    def forward(self, g, nids):
        """Embed the nodes of ``g``.

        Args:
            g: graph with ``edata[dgl.ETYPE]`` relation IDs and
                ``edata["norm"]`` edge normalizers.
            nids: IDs used to look up each node's input embedding (presumably
                one per node of ``g``, in ``g``'s node order; confirm with caller).

        Returns:
            Node representations after the second layer and dropout.
        """
        etypes = g.edata[dgl.ETYPE]
        norm = g.edata["norm"]
        hidden = self.emb(nids)
        hidden = self.dropout(F.relu(self.conv1(g, hidden, etypes, norm)))
        hidden = self.conv2(g, hidden, etypes, norm)
        return self.dropout(hidden)
Mufei Li's avatar
Mufei Li committed
138

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
139

Mufei Li's avatar
Mufei Li committed
140
class LinkPredict(nn.Module):
    """DistMult link predictor on top of an RGCN encoder.

    The encoder is built for ``num_rels * 2`` relation types. Message-passing
    graphs in this script add a reversed copy of each edge with relation ID
    ``rel + num_rels``. Only the ``num_rels`` forward relations get a
    DistMult vector in ``w_relation`` for scoring.

    Args:
        num_nodes: number of nodes (rows of the RGCN input embedding).
        num_rels: number of forward relation types.
        h_dim: embedding / hidden size.
        reg_param: weight of the L2 regularization term in ``get_loss``.
    """

    def __init__(self, num_nodes, num_rels, h_dim=500, reg_param=0.01):
        super().__init__()
        self.rgcn = RGCN(num_nodes, h_dim, num_rels * 2)
        self.reg_param = reg_param
        # torch.empty replaces the legacy torch.Tensor(...) constructor. The
        # storage is uninitialized either way and is filled by xavier below.
        self.w_relation = nn.Parameter(torch.empty(num_rels, h_dim))
        nn.init.xavier_uniform_(
            self.w_relation, gain=nn.init.calculate_gain("relu")
        )

    def calc_score(self, embedding, triplets):
        """Return the DistMult logit ``sum(e_s * w_r * e_o)`` for each triplet.

        Each row of ``triplets`` is ``(source, relation, destination)``, with
        node IDs indexing ``embedding`` and relation IDs indexing ``w_relation``.
        """
        s = embedding[triplets[:, 0]]
        r = self.w_relation[triplets[:, 1]]
        o = embedding[triplets[:, 2]]
        return torch.sum(s * r * o, dim=1)

    def forward(self, g, nids):
        """Return RGCN node embeddings for graph ``g``."""
        return self.rgcn(g, nids)

    def regularization_loss(self, embedding):
        """Mean squared value of the embeddings plus that of ``w_relation``."""
        return torch.mean(embedding.pow(2)) + torch.mean(self.w_relation.pow(2))

    def get_loss(self, embed, triplets, labels):
        """Binary cross-entropy on DistMult logits plus weighted L2 regularization.

        Args:
            embed: node embeddings indexed by the node IDs in ``triplets``.
            triplets: rows of ``(source, relation, destination)``.
            labels: float targets (1 = positive, 0 = negative), one per row.
        """
        score = self.calc_score(embed, triplets)
        predict_loss = F.binary_cross_entropy_with_logits(score, labels)
        reg_loss = self.regularization_loss(embed)
        return predict_loss + self.reg_param * reg_loss

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
170
171
172
173

def filter(
    triplets_to_filter, target_s, target_r, target_o, num_nodes, filter_o=True
):
    """Get candidate heads or tails to score (filtered evaluation setting).

    The ground-truth node always comes first (index 0). It is followed, in
    increasing ID order, by every node ``e`` in ``[0, num_nodes)`` whose
    corrupted triplet is not in ``triplets_to_filter``.

    Note: this name shadows Python's builtin ``filter`` within the module.

    Args:
        triplets_to_filter: container of known ``(s, r, o)`` int tuples,
            tested with ``in``.
        target_s, target_r, target_o: the test triplet; anything ``int()``
            accepts.
        num_nodes: number of candidate nodes to consider.
        filter_o: if true, corrupt the object ``(s, r, e)``; otherwise
            corrupt the subject ``(e, r, o)``.

    Returns:
        A LongTensor of candidate node IDs.
    """
    head, rel, tail = int(target_s), int(target_r), int(target_o)
    if filter_o:
        kept = [
            node
            for node in range(num_nodes)
            if (head, rel, node) not in triplets_to_filter
        ]
        return torch.LongTensor([tail] + kept)
    kept = [
        node
        for node in range(num_nodes)
        if (node, rel, tail) not in triplets_to_filter
    ]
    return torch.LongTensor([head] + kept)
Mufei Li's avatar
Mufei Li committed
189

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
190
191
192
193

def perturb_and_get_filtered_rank(
    emb, w, s, r, o, test_size, triplets_to_filter, filter_o=True
):
    """Perturb subject or object in the triplets and return filtered ranks.

    For each of the first ``test_size`` triplets ``(s[idx], r[idx], o[idx])``,
    the true triplet is scored together with every corruption of its object
    (``filter_o=True``) or subject (``filter_o=False``) that is not in
    ``triplets_to_filter``. Candidates come from ``filter``, which places the
    true node at index 0. Scores are DistMult logits
    ``sum(emb_s * w_r * emb_o)``.

    Args:
        emb: node embeddings; row count is the number of candidate nodes.
        w: relation vectors indexed by relation ID.
        s, r, o: indexable sequences of subject, relation and object IDs.
        test_size: number of triplets to evaluate.
        triplets_to_filter: container of known ``(s, r, o)`` int tuples.
        filter_o: corrupt objects if true, subjects otherwise.

    Returns:
        LongTensor of 0-indexed ranks of the true triplet, one per test triplet.
    """
    num_nodes = emb.shape[0]
    ranks = []
    for idx in tqdm.tqdm(range(test_size), desc="Evaluate"):
        target_s = s[idx]
        target_r = r[idx]
        target_o = o[idx]
        candidate_nodes = filter(
            triplets_to_filter,
            target_s,
            target_r,
            target_o,
            num_nodes,
            filter_o=filter_o,
        )
        if filter_o:
            emb_s = emb[target_s]
            emb_o = emb[candidate_nodes]
        else:
            emb_s = emb[candidate_nodes]
            emb_o = emb[target_o]
        # filter() always puts the ground-truth node first.
        target_idx = 0
        emb_r = w[target_r]
        # Rank on raw logits. A sigmoid is monotonic, so it cannot change the
        # order, but in float32 it saturates to exactly 1.0 for large logits.
        # That would turn distinct scores into artificial ties whose relative
        # order torch.sort leaves unspecified.
        scores = torch.sum(emb_s * emb_r * emb_o, dim=1)
        _, indices = torch.sort(scores, descending=True)
        rank = int((indices == target_idx).nonzero())
        ranks.append(rank)
    return torch.LongTensor(ranks)
Mufei Li's avatar
Mufei Li committed
224

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
225
226
227
228

def calc_mrr(
    emb, w, test_mask, triplets_to_filter, batch_size=100, filter=True
):
    """Compute the filtered mean reciprocal rank (MRR) of the masked triplets.

    The evaluated triplets are ``triplets_to_filter[test_mask]``. Each one is
    ranked twice: once with its subject corrupted and once with its object
    corrupted. Every row of ``triplets_to_filter`` is treated as a known true
    triplet and excluded from the corruptions. The MRR is the mean of
    ``1 / rank`` over both rankings, with 1-indexed ranks.

    Args:
        emb: node embeddings.
        w: relation vectors indexed by relation ID.
        test_mask: mask/index selecting the rows to evaluate.
        triplets_to_filter: tensor of ``(s, r, o)`` rows covering all known
            triplets.
        batch_size: accepted for interface compatibility; unused here.
        filter: accepted for interface compatibility; unused here. Filtering
            is always applied.

    Returns:
        The MRR as a Python float.
    """
    with torch.no_grad():
        eval_triplets = triplets_to_filter[test_mask]
        heads = eval_triplets[:, 0]
        rels = eval_triplets[:, 1]
        tails = eval_triplets[:, 2]
        n_eval = len(heads)

        # Hashable tuples give O(1) membership checks during filtering.
        known = set(map(tuple, triplets_to_filter.tolist()))

        head_ranks = perturb_and_get_filtered_rank(
            emb, w, heads, rels, tails, n_eval, known, filter_o=False
        )
        tail_ranks = perturb_and_get_filtered_rank(
            emb, w, heads, rels, tails, n_eval, known
        )
        one_based = torch.cat([head_ranks, tail_ranks]) + 1
        mrr = (1.0 / one_based.float()).mean().item()
    return mrr
Mufei Li's avatar
Mufei Li committed
246

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
247
248
249
250
251
252
253
254
255
256
257

def train(
    dataloader,
    test_g,
    test_nids,
    test_mask,
    triplets,
    device,
    model_state_file,
    model,
):
    """Train ``model`` on sampled subgraphs with periodic filtered-MRR validation.

    One Adam step (lr=1e-2, gradient norm clipped to 1.0) is taken per
    dataloader item. Every 500 items the model is moved to the CPU and run on
    ``test_g`` to compute the MRR of the ``test_mask`` triplets. When that
    MRR improves, ``{"state_dict", "epoch"}`` is saved to
    ``model_state_file``. The model is then moved back to ``device``.

    Args:
        dataloader: yields ``(g, train_nids, edges, labels)`` batches.
        test_g: message-passing graph used for validation embeddings.
        test_nids: node IDs fed to the model together with ``test_g``.
        test_mask: mask selecting the validation triplets in ``triplets``.
        triplets: tensor of all known ``(s, r, o)`` triplets.
        device: device used for training steps.
        model_state_file: path of the best-model checkpoint.
        model: a ``LinkPredict``-style model.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    best_mrr = 0
    for epoch, batch_data in enumerate(dataloader):  # single graph batch
        model.train()
        g, train_nids, edges, labels = batch_data
        g = g.to(device)
        train_nids = train_nids.to(device)
        edges = edges.to(device)
        labels = labels.to(device)

        embed = model(g, train_nids)
        loss = model.get_loss(embed, edges, labels)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=1.0
        )  # clip gradients
        optimizer.step()
        print(
            "Epoch {:04d} | Loss {:.4f} | Best MRR {:.4f}".format(
                epoch, loss.item(), best_mrr
            )
        )
        if (epoch + 1) % 500 == 0:
            # perform validation on CPU because full graph is too large
            model = model.cpu()
            model.eval()
            # Inference only. Without no_grad the full-graph forward pass
            # would record an autograd graph and keep every intermediate
            # activation alive, even though it is never backpropagated.
            with torch.no_grad():
                embed = model(test_g, test_nids)
            mrr = calc_mrr(
                embed, model.w_relation, test_mask, triplets, batch_size=500
            )
            # save best model
            if best_mrr < mrr:
                best_mrr = mrr
                torch.save(
                    {"state_dict": model.state_dict(), "epoch": epoch},
                    model_state_file,
                )
            model = model.to(device)

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
298
299
300
301

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Training with DGL built-in RGCN module")

    # load and preprocess dataset
    data = FB15k237Dataset(reverse=False)
    g = data[0]
    num_nodes = g.num_nodes()
    num_rels = data.num_rels
    train_g = get_subset_g(g, g.edata["train_mask"], num_rels)
    # The evaluation message-passing graph is built from the TRAINING edges
    # (plus reversed copies). Test edges are only scored via test_mask and
    # are never used to propagate messages.
    test_g = get_subset_g(g, g.edata["train_mask"], num_rels, bidirected=True)
    test_g.edata["norm"] = dgl.norm_by_dst(test_g).unsqueeze(-1)
    test_nids = torch.arange(0, num_nodes)
    test_mask = g.edata["test_mask"]
    subg_iter = SubgraphIterator(train_g, num_rels)  # uniform edge sampling
    dataloader = GraphDataLoader(
        subg_iter, batch_size=1, collate_fn=lambda x: x[0]
    )

    # Prepare data for metric computation: all known (s, r, o) triplets.
    src, dst = g.edges()
    triplets = torch.stack([src, g.edata["etype"], dst], dim=1)

    # create RGCN model
    model = LinkPredict(num_nodes, num_rels).to(device)

    # train
    model_state_file = "model_state.pth"
    train(
        dataloader,
        test_g,
        test_nids,
        test_mask,
        triplets,
        device,
        model_state_file,
        model,
    )

    # testing
    print("Testing...")
    # Testing runs on the CPU, so load the checkpoint tensors onto the CPU.
    checkpoint = torch.load(model_state_file, map_location="cpu")
    model = model.cpu()  # test on CPU
    model.eval()
    model.load_state_dict(checkpoint["state_dict"])
    # Inference only: do not build an autograd graph for the full-graph pass.
    with torch.no_grad():
        embed = model(test_g, test_nids)
    best_mrr = calc_mrr(
        embed, model.w_relation, test_mask, triplets, batch_size=500
    )
    print(
        "Best MRR {:.4f} achieved using the epoch {:04d}".format(
            best_mrr, checkpoint["epoch"]
        )
    )