Unverified Commit bc3da371 authored by Songqing Zhang's avatar Songqing Zhang Committed by GitHub
Browse files

[Misc] Fix validation dataset's usage in RGCN link example (#7308)

parent 6309483d
......@@ -58,4 +58,4 @@ Summary
### Link Prediction
| Dataset | Best MRR
| ------------- | -------
| FB15k-237 | ~0.2439
| FB15k-237 | ~0.2397
......@@ -223,11 +223,9 @@ def perturb_and_get_filtered_rank(
return torch.LongTensor(ranks)
def calc_mrr(
emb, w, test_mask, triplets_to_filter, batch_size=100, filter=True
):
def calc_mrr(emb, w, mask, triplets_to_filter, batch_size=100, filter=True):
with torch.no_grad():
test_triplets = triplets_to_filter[test_mask]
test_triplets = triplets_to_filter[mask]
s, r, o = test_triplets[:, 0], test_triplets[:, 1], test_triplets[:, 2]
test_size = len(s)
triplets_to_filter = {
......@@ -249,7 +247,7 @@ def train(
dataloader,
test_g,
test_nids,
test_mask,
val_mask,
triplets,
device,
model_state_file,
......@@ -284,7 +282,7 @@ def train(
model.eval()
embed = model(test_g, test_nids)
mrr = calc_mrr(
embed, model.w_relation, test_mask, triplets, batch_size=500
embed, model.w_relation, val_mask, triplets, batch_size=500
)
# save best model
if best_mrr < mrr:
......@@ -309,6 +307,7 @@ if __name__ == "__main__":
test_g = get_subset_g(g, g.edata["train_mask"], num_rels, bidirected=True)
test_g.edata["norm"] = dgl.norm_by_dst(test_g).unsqueeze(-1)
test_nids = torch.arange(0, num_nodes)
val_mask = g.edata["val_mask"]
test_mask = g.edata["test_mask"]
subg_iter = SubgraphIterator(train_g, num_rels) # uniform edge sampling
dataloader = GraphDataLoader(
......@@ -328,7 +327,7 @@ if __name__ == "__main__":
dataloader,
test_g,
test_nids,
test_mask,
val_mask,
triplets,
device,
model_state_file,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment