"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "2eaefe2745bff1a898394f9f447c011bc50daedc"
Commit 77c58289 authored by Sahand's avatar Sahand Committed by Minjie Wang
Browse files

[Bugfix][Model] Fixing RGCN evaluation bug (#778)

parent 788420df
...@@ -182,7 +182,7 @@ def main(args): ...@@ -182,7 +182,7 @@ def main(args):
model.cpu() model.cpu()
model.eval() model.eval()
print("start eval") print("start eval")
mrr = utils.evaluate(test_graph, model, valid_data, num_nodes, mrr = utils.evaluate(test_graph, model, valid_data,
hits=[1, 3, 10], eval_bz=args.eval_batch_size) hits=[1, 3, 10], eval_bz=args.eval_batch_size)
# save best model # save best model
if mrr < best_mrr: if mrr < best_mrr:
...@@ -207,7 +207,7 @@ def main(args): ...@@ -207,7 +207,7 @@ def main(args):
model.eval() model.eval()
model.load_state_dict(checkpoint['state_dict']) model.load_state_dict(checkpoint['state_dict'])
print("Using best epoch: {}".format(checkpoint['epoch'])) print("Using best epoch: {}".format(checkpoint['epoch']))
utils.evaluate(test_graph, model, test_data, num_nodes, hits=[1, 3, 10], utils.evaluate(test_graph, model, test_data, hits=[1, 3, 10],
eval_bz=args.eval_batch_size) eval_bz=args.eval_batch_size)
......
...@@ -163,15 +163,15 @@ def sort_and_rank(score, target): ...@@ -163,15 +163,15 @@ def sort_and_rank(score, target):
indices = indices[:, 1].view(-1) indices = indices[:, 1].view(-1)
return indices return indices
def perturb_and_get_rank(embedding, w, a, r, b, num_entity, batch_size=100): def perturb_and_get_rank(embedding, w, a, r, b, test_size, batch_size=100):
""" Perturb one element in the triplets """ Perturb one element in the triplets
""" """
n_batch = (num_entity + batch_size - 1) // batch_size n_batch = (test_size + batch_size - 1) // batch_size
ranks = [] ranks = []
for idx in range(n_batch): for idx in range(n_batch):
print("batch {} / {}".format(idx, n_batch)) print("batch {} / {}".format(idx, n_batch))
batch_start = idx * batch_size batch_start = idx * batch_size
batch_end = min(num_entity, (idx + 1) * batch_size) batch_end = min(test_size, (idx + 1) * batch_size)
batch_a = a[batch_start: batch_end] batch_a = a[batch_start: batch_end]
batch_r = r[batch_start: batch_end] batch_r = r[batch_start: batch_end]
emb_ar = embedding[batch_a] * w[batch_r] emb_ar = embedding[batch_a] * w[batch_r]
...@@ -187,17 +187,18 @@ def perturb_and_get_rank(embedding, w, a, r, b, num_entity, batch_size=100): ...@@ -187,17 +187,18 @@ def perturb_and_get_rank(embedding, w, a, r, b, num_entity, batch_size=100):
# TODO (lingfan): implement filtered metrics # TODO (lingfan): implement filtered metrics
# return MRR (raw), and Hits @ (1, 3, 10) # return MRR (raw), and Hits @ (1, 3, 10)
def evaluate(test_graph, model, test_triplets, num_entity, hits=[], eval_bz=100): def evaluate(test_graph, model, test_triplets, hits=[], eval_bz=100):
with torch.no_grad(): with torch.no_grad():
embedding, w = model.evaluate(test_graph) embedding, w = model.evaluate(test_graph)
s = test_triplets[:, 0] s = test_triplets[:, 0]
r = test_triplets[:, 1] r = test_triplets[:, 1]
o = test_triplets[:, 2] o = test_triplets[:, 2]
test_size = test_triplets.shape[0]
# perturb subject # perturb subject
ranks_s = perturb_and_get_rank(embedding, w, o, r, s, num_entity, eval_bz) ranks_s = perturb_and_get_rank(embedding, w, o, r, s, test_size, eval_bz)
# perturb object # perturb object
ranks_o = perturb_and_get_rank(embedding, w, s, r, o, num_entity, eval_bz) ranks_o = perturb_and_get_rank(embedding, w, s, r, o, test_size, eval_bz)
ranks = torch.cat([ranks_s, ranks_o]) ranks = torch.cat([ranks_s, ranks_o])
ranks += 1 # change to 1-indexed ranks += 1 # change to 1-indexed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment