Unverified Commit 6c7c4039 authored by hgalioulline's avatar hgalioulline Committed by GitHub
Browse files

[Feature] Filtered MRR metrics for R-GCN example (#1298)



* Add filtered metrics for R-GCN example

* Add new line to end of file

* Add evaluation protocol argument option for R-GCN example

* Update README
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 87842db4
...@@ -40,5 +40,9 @@ python3 entity_classify.py -d am --n-bases=40 --n-hidden=10 --l2norm=5e-4 --test ...@@ -40,5 +40,9 @@ python3 entity_classify.py -d am --n-bases=40 --n-hidden=10 --l2norm=5e-4 --test
### Link Prediction ### Link Prediction
FB15k-237: MRR 0.151 (DGL), 0.158 (paper) FB15k-237: MRR 0.151 (DGL), 0.158 (paper)
``` ```
python3 link_predict.py -d FB15k-237 --gpu 0 python3 link_predict.py -d FB15k-237 --gpu 0 --raw
```
FB15k-237: Filtered-MRR 0.2044
```
python3 link_predict.py -d FB15k-237 --gpu 0 --filtered
``` ```
...@@ -186,8 +186,9 @@ def main(args): ...@@ -186,8 +186,9 @@ def main(args):
model.eval() model.eval()
print("start eval") print("start eval")
embed = model(test_graph, test_node_id, test_rel, test_norm) embed = model(test_graph, test_node_id, test_rel, test_norm)
mrr = utils.calc_mrr(embed, model.w_relation, valid_data, mrr = utils.calc_mrr(embed, model.w_relation, torch.LongTensor(train_data),
hits=[1, 3, 10], eval_bz=args.eval_batch_size) valid_data, test_data, hits=[1, 3, 10], eval_bz=args.eval_batch_size,
eval_p=args.eval_protocol)
# save best model # save best model
if mrr < best_mrr: if mrr < best_mrr:
if epoch >= args.n_epochs: if epoch >= args.n_epochs:
...@@ -212,8 +213,8 @@ def main(args): ...@@ -212,8 +213,8 @@ def main(args):
model.load_state_dict(checkpoint['state_dict']) model.load_state_dict(checkpoint['state_dict'])
print("Using best epoch: {}".format(checkpoint['epoch'])) print("Using best epoch: {}".format(checkpoint['epoch']))
embed = model(test_graph, test_node_id, test_rel, test_norm) embed = model(test_graph, test_node_id, test_rel, test_norm)
utils.calc_mrr(embed, model.w_relation, test_data, utils.calc_mrr(embed, model.w_relation, torch.LongTensor(train_data), valid_data,
hits=[1, 3, 10], eval_bz=args.eval_batch_size) test_data, hits=[1, 3, 10], eval_bz=args.eval_batch_size, eval_p=args.eval_protocol)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='RGCN') parser = argparse.ArgumentParser(description='RGCN')
...@@ -235,6 +236,8 @@ if __name__ == '__main__': ...@@ -235,6 +236,8 @@ if __name__ == '__main__':
help="dataset to use") help="dataset to use")
parser.add_argument("--eval-batch-size", type=int, default=500, parser.add_argument("--eval-batch-size", type=int, default=500,
help="batch size when evaluating") help="batch size when evaluating")
parser.add_argument("--eval-protocol", type=str, default="filtered",
help="type of evaluation protocol: 'raw' or 'filtered' mrr")
parser.add_argument("--regularization", type=float, default=0.01, parser.add_argument("--regularization", type=float, default=0.01,
help="regularization weight") help="regularization weight")
parser.add_argument("--grad-norm", type=float, default=1.0, parser.add_argument("--grad-norm", type=float, default=1.0,
......
...@@ -165,7 +165,7 @@ def negative_sampling(pos_samples, num_entity, negative_rate): ...@@ -165,7 +165,7 @@ def negative_sampling(pos_samples, num_entity, negative_rate):
####################################################################### #######################################################################
# #
# Utility function for evaluations # Utility functions for evaluations (raw)
# #
####################################################################### #######################################################################
...@@ -175,7 +175,7 @@ def sort_and_rank(score, target): ...@@ -175,7 +175,7 @@ def sort_and_rank(score, target):
indices = indices[:, 1].view(-1) indices = indices[:, 1].view(-1)
return indices return indices
def perturb_and_get_rank(embedding, w, a, r, b, test_size, batch_size=100): def perturb_and_get_raw_rank(embedding, w, a, r, b, test_size, batch_size=100):
""" Perturb one element in the triplets """ Perturb one element in the triplets
""" """
n_batch = (test_size + batch_size - 1) // batch_size n_batch = (test_size + batch_size - 1) // batch_size
...@@ -197,9 +197,8 @@ def perturb_and_get_rank(embedding, w, a, r, b, test_size, batch_size=100): ...@@ -197,9 +197,8 @@ def perturb_and_get_rank(embedding, w, a, r, b, test_size, batch_size=100):
ranks.append(sort_and_rank(score, target)) ranks.append(sort_and_rank(score, target))
return torch.cat(ranks) return torch.cat(ranks)
# TODO (lingfan): implement filtered metrics
# return MRR (raw), and Hits @ (1, 3, 10) # return MRR (raw), and Hits @ (1, 3, 10)
def calc_mrr(embedding, w, test_triplets, hits=[], eval_bz=100): def calc_raw_mrr(embedding, w, test_triplets, hits=[], eval_bz=100):
with torch.no_grad(): with torch.no_grad():
s = test_triplets[:, 0] s = test_triplets[:, 0]
r = test_triplets[:, 1] r = test_triplets[:, 1]
...@@ -207,9 +206,9 @@ def calc_mrr(embedding, w, test_triplets, hits=[], eval_bz=100): ...@@ -207,9 +206,9 @@ def calc_mrr(embedding, w, test_triplets, hits=[], eval_bz=100):
test_size = test_triplets.shape[0] test_size = test_triplets.shape[0]
# perturb subject # perturb subject
ranks_s = perturb_and_get_rank(embedding, w, o, r, s, test_size, eval_bz) ranks_s = perturb_and_get_raw_rank(embedding, w, o, r, s, test_size, eval_bz)
# perturb object # perturb object
ranks_o = perturb_and_get_rank(embedding, w, s, r, o, test_size, eval_bz) ranks_o = perturb_and_get_raw_rank(embedding, w, s, r, o, test_size, eval_bz)
ranks = torch.cat([ranks_s, ranks_o]) ranks = torch.cat([ranks_s, ranks_o])
ranks += 1 # change to 1-indexed ranks += 1 # change to 1-indexed
...@@ -221,3 +220,117 @@ def calc_mrr(embedding, w, test_triplets, hits=[], eval_bz=100): ...@@ -221,3 +220,117 @@ def calc_mrr(embedding, w, test_triplets, hits=[], eval_bz=100):
avg_count = torch.mean((ranks <= hit).float()) avg_count = torch.mean((ranks <= hit).float())
print("Hits (raw) @ {}: {:.6f}".format(hit, avg_count.item())) print("Hits (raw) @ {}: {:.6f}".format(hit, avg_count.item()))
return mrr.item() return mrr.item()
#######################################################################
#
# Utility functions for evaluations (filtered)
#
#######################################################################
def filter_o(triplets_to_filter, target_s, target_r, target_o, num_entities):
target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
filtered_o = []
# Do not filter out the test triplet, since we want to predict on it
if (target_s, target_r, target_o) in triplets_to_filter:
triplets_to_filter.remove((target_s, target_r, target_o))
# Do not consider an object if it is part of a triplet to filter
for o in range(num_entities):
if (target_s, target_r, o) not in triplets_to_filter:
filtered_o.append(o)
return torch.LongTensor(filtered_o)
def filter_s(triplets_to_filter, target_s, target_r, target_o, num_entities):
target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
filtered_s = []
# Do not filter out the test triplet, since we want to predict on it
if (target_s, target_r, target_o) in triplets_to_filter:
triplets_to_filter.remove((target_s, target_r, target_o))
# Do not consider a subject if it is part of a triplet to filter
for s in range(num_entities):
if (s, target_r, target_o) not in triplets_to_filter:
filtered_s.append(s)
return torch.LongTensor(filtered_s)
def perturb_o_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter):
""" Perturb object in the triplets
"""
num_entities = embedding.shape[0]
ranks = []
for idx in range(test_size):
if idx % 100 == 0:
print("test triplet {} / {}".format(idx, test_size))
target_s = s[idx]
target_r = r[idx]
target_o = o[idx]
filtered_o = filter_o(triplets_to_filter, target_s, target_r, target_o, num_entities)
target_o_idx = int((filtered_o == target_o).nonzero())
emb_s = embedding[target_s]
emb_r = w[target_r]
emb_o = embedding[filtered_o]
emb_triplet = emb_s * emb_r * emb_o
scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))
_, indices = torch.sort(scores, descending=True)
rank = int((indices == target_o_idx).nonzero())
ranks.append(rank)
return torch.LongTensor(ranks)
def perturb_s_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter):
""" Perturb subject in the triplets
"""
num_entities = embedding.shape[0]
ranks = []
for idx in range(test_size):
if idx % 100 == 0:
print("test triplet {} / {}".format(idx, test_size))
target_s = s[idx]
target_r = r[idx]
target_o = o[idx]
filtered_s = filter_s(triplets_to_filter, target_s, target_r, target_o, num_entities)
target_s_idx = int((filtered_s == target_s).nonzero())
emb_s = embedding[filtered_s]
emb_r = w[target_r]
emb_o = embedding[target_o]
emb_triplet = emb_s * emb_r * emb_o
scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))
_, indices = torch.sort(scores, descending=True)
rank = int((indices == target_s_idx).nonzero())
ranks.append(rank)
return torch.LongTensor(ranks)
def calc_filtered_mrr(embedding, w, train_triplets, valid_triplets, test_triplets, hits=[]):
with torch.no_grad():
s = test_triplets[:, 0]
r = test_triplets[:, 1]
o = test_triplets[:, 2]
test_size = test_triplets.shape[0]
triplets_to_filter = torch.cat([train_triplets, valid_triplets, test_triplets]).tolist()
triplets_to_filter = {tuple(triplet) for triplet in triplets_to_filter}
print('Perturbing subject...')
ranks_s = perturb_s_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter)
print('Perturbing object...')
ranks_o = perturb_o_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter)
ranks = torch.cat([ranks_s, ranks_o])
ranks += 1 # change to 1-indexed
mrr = torch.mean(1.0 / ranks.float())
print("MRR (filtered): {:.6f}".format(mrr.item()))
for hit in hits:
avg_count = torch.mean((ranks <= hit).float())
print("Hits (filtered) @ {}: {:.6f}".format(hit, avg_count.item()))
return mrr.item()
#######################################################################
#
# Main evaluation function
#
#######################################################################
def calc_mrr(embedding, w, train_triplets, valid_triplets, test_triplets, hits=[], eval_bz=100, eval_p="filtered"):
if eval_p == "filtered":
mrr = calc_filtered_mrr(embedding, w, train_triplets, valid_triplets, test_triplets, hits)
else:
mrr = calc_raw_mrr(embedding, w, test_triplets, hits, eval_bz)
return mrr
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment