"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c4ef1efe4649b87cb4fc6ff3547556c4aaa6fb64"
Unverified Commit 9432cd63 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

fix sampler. (#1350)

parent 988c8b20
...@@ -434,14 +434,19 @@ def create_neg_subgraph(pos_g, neg_g, chunk_size, neg_sample_size, is_chunked, ...@@ -434,14 +434,19 @@ def create_neg_subgraph(pos_g, neg_g, chunk_size, neg_sample_size, is_chunked,
# We use all nodes to create negative edges. Regardless of the sampling algorithm, # We use all nodes to create negative edges. Regardless of the sampling algorithm,
# we can always view the subgraph with one chunk. # we can always view the subgraph with one chunk.
if (neg_head and len(neg_g.head_nid) == num_nodes) \ if (neg_head and len(neg_g.head_nid) == num_nodes) \
or (not neg_head and len(neg_g.tail_nid) == num_nodes) \ or (not neg_head and len(neg_g.tail_nid) == num_nodes):
or pos_g.number_of_edges() < chunk_size:
num_chunks = 1 num_chunks = 1
chunk_size = pos_g.number_of_edges() chunk_size = pos_g.number_of_edges()
elif is_chunked: elif is_chunked:
# This is probably the last batch. Let's ignore it. # This is probably for evaluation.
if pos_g.number_of_edges() % chunk_size > 0: if pos_g.number_of_edges() < chunk_size \
and neg_g.number_of_edges() % neg_sample_size == 0:
num_chunks = 1
chunk_size = pos_g.number_of_edges()
# This is probably the last batch in the training. Let's ignore it.
elif pos_g.number_of_edges() % chunk_size > 0:
return None return None
else:
num_chunks = int(pos_g.number_of_edges() / chunk_size) num_chunks = int(pos_g.number_of_edges() / chunk_size)
assert num_chunks * chunk_size == pos_g.number_of_edges() assert num_chunks * chunk_size == pos_g.number_of_edges()
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment