Unverified Commit 91f7ff8e authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix #1421 (#1422)

parent 7c47d8c9
...@@ -73,15 +73,13 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows, ...@@ -73,15 +73,13 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
IdxType* picked_idata = static_cast<IdxType*>(picked_idx->data); IdxType* picked_idata = static_cast<IdxType*>(picked_idx->data);
bool all_has_fanout = true; bool all_has_fanout = true;
if (replace) {
all_has_fanout = true;
} else {
#pragma omp parallel for reduction(&&:all_has_fanout) #pragma omp parallel for reduction(&&:all_has_fanout)
for (int64_t i = 0; i < num_rows; ++i) { for (int64_t i = 0; i < num_rows; ++i) {
const IdxType rid = rows_data[i]; const IdxType rid = rows_data[i];
const IdxType len = indptr[rid + 1] - indptr[rid]; const IdxType len = indptr[rid + 1] - indptr[rid];
all_has_fanout = all_has_fanout && (len >= num_picks); // If a node has no neighbor then all_has_fanout must be false even if replace is
} // true.
all_has_fanout = all_has_fanout && (len >= (replace ? 1 : num_picks));
} }
#pragma omp parallel for #pragma omp parallel for
......
...@@ -460,10 +460,14 @@ def test_sample_neighbors_topk_outedge(): ...@@ -460,10 +460,14 @@ def test_sample_neighbors_topk_outedge():
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented") @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_sample_neighbors_with_0deg(): def test_sample_neighbors_with_0deg():
g = dgl.graph([], num_nodes=5) g = dgl.graph([], num_nodes=5)
dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='in', replace=False) sg = dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='in', replace=False)
dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='in', replace=True) assert sg.number_of_edges() == 0
dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='out', replace=False) sg = dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='in', replace=True)
dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='out', replace=True) assert sg.number_of_edges() == 0
sg = dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='out', replace=False)
assert sg.number_of_edges() == 0
sg = dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='out', replace=True)
assert sg.number_of_edges() == 0
if __name__ == '__main__': if __name__ == '__main__':
test_random_walk() test_random_walk()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment