Unverified Commit 251a9842 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

fix (#898)

parent 190cdbd2
...@@ -993,6 +993,8 @@ NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &po ...@@ -993,6 +993,8 @@ NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &po
int neg_sample_size, bool exclude_positive, int neg_sample_size, bool exclude_positive,
bool check_false_neg) { bool check_false_neg) {
int64_t num_tot_nodes = gptr->NumVertices(); int64_t num_tot_nodes = gptr->NumVertices();
if (neg_sample_size > num_tot_nodes)
neg_sample_size = num_tot_nodes;
bool is_multigraph = gptr->IsMultigraph(); bool is_multigraph = gptr->IsMultigraph();
std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo"); std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo");
IdArray coo = adj[0]; IdArray coo = adj[0];
...@@ -1165,6 +1167,8 @@ NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph ...@@ -1165,6 +1167,8 @@ NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph
std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo"); std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo");
IdArray coo = adj[0]; IdArray coo = adj[0];
int64_t num_pos_edges = coo->shape[0] / 2; int64_t num_pos_edges = coo->shape[0] / 2;
if (neg_sample_size > num_tot_nodes)
neg_sample_size = num_tot_nodes;
int64_t chunk_size = neg_sample_size; int64_t chunk_size = neg_sample_size;
// If num_pos_edges isn't divisible by chunk_size, the actual number of chunks // If num_pos_edges isn't divisible by chunk_size, the actual number of chunks
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment