Unverified Commit bc3da371 authored by Songqing Zhang's avatar Songqing Zhang Committed by GitHub
Browse files

[Misc] Fix validation dataset's usage in RGCN link example (#7308)

parent 6309483d
...@@ -58,4 +58,4 @@ Summary ...@@ -58,4 +58,4 @@ Summary
### Link Prediction ### Link Prediction
| Dataset | Best MRR | Dataset | Best MRR
| ------------- | ------- | ------------- | -------
| FB15k-237 | ~0.2439 | FB15k-237 | ~0.2397
...@@ -223,11 +223,9 @@ def perturb_and_get_filtered_rank( ...@@ -223,11 +223,9 @@ def perturb_and_get_filtered_rank(
return torch.LongTensor(ranks) return torch.LongTensor(ranks)
def calc_mrr( def calc_mrr(emb, w, mask, triplets_to_filter, batch_size=100, filter=True):
emb, w, test_mask, triplets_to_filter, batch_size=100, filter=True
):
with torch.no_grad(): with torch.no_grad():
test_triplets = triplets_to_filter[test_mask] test_triplets = triplets_to_filter[mask]
s, r, o = test_triplets[:, 0], test_triplets[:, 1], test_triplets[:, 2] s, r, o = test_triplets[:, 0], test_triplets[:, 1], test_triplets[:, 2]
test_size = len(s) test_size = len(s)
triplets_to_filter = { triplets_to_filter = {
...@@ -249,7 +247,7 @@ def train( ...@@ -249,7 +247,7 @@ def train(
dataloader, dataloader,
test_g, test_g,
test_nids, test_nids,
test_mask, val_mask,
triplets, triplets,
device, device,
model_state_file, model_state_file,
...@@ -284,7 +282,7 @@ def train( ...@@ -284,7 +282,7 @@ def train(
model.eval() model.eval()
embed = model(test_g, test_nids) embed = model(test_g, test_nids)
mrr = calc_mrr( mrr = calc_mrr(
embed, model.w_relation, test_mask, triplets, batch_size=500 embed, model.w_relation, val_mask, triplets, batch_size=500
) )
# save best model # save best model
if best_mrr < mrr: if best_mrr < mrr:
...@@ -309,6 +307,7 @@ if __name__ == "__main__": ...@@ -309,6 +307,7 @@ if __name__ == "__main__":
test_g = get_subset_g(g, g.edata["train_mask"], num_rels, bidirected=True) test_g = get_subset_g(g, g.edata["train_mask"], num_rels, bidirected=True)
test_g.edata["norm"] = dgl.norm_by_dst(test_g).unsqueeze(-1) test_g.edata["norm"] = dgl.norm_by_dst(test_g).unsqueeze(-1)
test_nids = torch.arange(0, num_nodes) test_nids = torch.arange(0, num_nodes)
val_mask = g.edata["val_mask"]
test_mask = g.edata["test_mask"] test_mask = g.edata["test_mask"]
subg_iter = SubgraphIterator(train_g, num_rels) # uniform edge sampling subg_iter = SubgraphIterator(train_g, num_rels) # uniform edge sampling
dataloader = GraphDataLoader( dataloader = GraphDataLoader(
...@@ -328,7 +327,7 @@ if __name__ == "__main__": ...@@ -328,7 +327,7 @@ if __name__ == "__main__":
dataloader, dataloader,
test_g, test_g,
test_nids, test_nids,
test_mask, val_mask,
triplets, triplets,
device, device,
model_state_file, model_state_file,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment