"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f7dfcfd971131fbaae43d9d9f59e5e3a9aa6234a"
Unverified Commit e26d2064 authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Example][Bugfix] Fix link pred example in graphsage (#4255)

parent 1c9528f5
...@@ -106,7 +106,7 @@ def compute_mrr(model, node_emb, src, dst, neg_dst, device, batch_size=500): ...@@ -106,7 +106,7 @@ def compute_mrr(model, node_emb, src, dst, neg_dst, device, batch_size=500):
h_src = node_emb[src[start:end]][:, None, :].to(device) h_src = node_emb[src[start:end]][:, None, :].to(device)
h_dst = node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(device) h_dst = node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(device)
pred = model.predict(h_src, h_dst).squeeze(-1) pred = model.predict(h_src, h_dst).squeeze(-1)
relevance = torch.zeros(*pred.shape, dtype=torch.bool) relevance = torch.zeros(*pred.shape, dtype=torch.bool).to(pred.device)
relevance[:, 0] = True relevance[:, 0] = True
rr[start:end] = MF.retrieval_reciprocal_rank(pred, relevance) rr[start:end] = MF.retrieval_reciprocal_rank(pred, relevance)
return rr.mean() return rr.mean()
...@@ -117,9 +117,9 @@ def evaluate(model, edge_split, device, num_workers): ...@@ -117,9 +117,9 @@ def evaluate(model, edge_split, device, num_workers):
node_emb = model.inference(graph, device, 4096, num_workers, 'cpu') node_emb = model.inference(graph, device, 4096, num_workers, 'cpu')
results = [] results = []
for split in ['valid', 'test']: for split in ['valid', 'test']:
src = edge_split[split]['source_node'].to(device) src = edge_split[split]['source_node'].to(node_emb.device)
dst = edge_split[split]['target_node'].to(device) dst = edge_split[split]['target_node'].to(node_emb.device)
neg_dst = edge_split[split]['target_node_neg'].to(device) neg_dst = edge_split[split]['target_node_neg'].to(node_emb.device)
results.append(compute_mrr(model, node_emb, src, dst, neg_dst, device)) results.append(compute_mrr(model, node_emb, src, dst, neg_dst, device))
return results return results
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment