Unverified Commit f7360c3c authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix link prediction (#3485)

parent db78fac5
......@@ -99,7 +99,7 @@ neg_u, neg_v = np.where(adj_neg != 0)
neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
train_neg_u, train_neg_v = neg_u[neg_eids[train_size:]], neg_v[neg_eids[train_size:]]
train_neg_u, train_neg_v = neg_u[neg_eids[test_size:]], neg_v[neg_eids[test_size:]]
######################################################################
......
......@@ -395,7 +395,7 @@ test_neg_dst = torch.randint(0, graph.num_nodes(), (graph.num_edges(),))
# You also need to label the edges, 1 if positive and 0 if negative.
#
test_src = torch.cat([test_pos_src, test_neg_src])
test_src = torch.cat([test_pos_src, test_pos_dst])
test_dst = torch.cat([test_neg_src, test_neg_dst])
test_graph = dgl.graph((test_src, test_dst), num_nodes=graph.num_nodes())
test_graph.edata['label'] = torch.cat(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment