Unverified Commit e26d2064 authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Example][Bugfix] Fix link pred example in graphsage (#4255)

parent 1c9528f5
......@@ -106,7 +106,7 @@ def compute_mrr(model, node_emb, src, dst, neg_dst, device, batch_size=500):
h_src = node_emb[src[start:end]][:, None, :].to(device)
h_dst = node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(device)
pred = model.predict(h_src, h_dst).squeeze(-1)
relevance = torch.zeros(*pred.shape, dtype=torch.bool)
relevance = torch.zeros(*pred.shape, dtype=torch.bool).to(pred.device)
relevance[:, 0] = True
rr[start:end] = MF.retrieval_reciprocal_rank(pred, relevance)
return rr.mean()
......@@ -117,9 +117,9 @@ def evaluate(model, edge_split, device, num_workers):
node_emb = model.inference(graph, device, 4096, num_workers, 'cpu')
results = []
for split in ['valid', 'test']:
src = edge_split[split]['source_node'].to(device)
dst = edge_split[split]['target_node'].to(device)
neg_dst = edge_split[split]['target_node_neg'].to(device)
src = edge_split[split]['source_node'].to(node_emb.device)
dst = edge_split[split]['target_node'].to(node_emb.device)
neg_dst = edge_split[split]['target_node_neg'].to(node_emb.device)
results.append(compute_mrr(model, node_emb, src, dst, neg_dst, device))
return results
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment