Unverified Commit 97b08fbb authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[BUG] fix crash in neighbor sampling w/ replacement on 0 degree nodes (#1402)

parent e3a9a6bb
...@@ -90,6 +90,9 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows, ...@@ -90,6 +90,9 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
CHECK_LT(rid, mat.num_rows); CHECK_LT(rid, mat.num_rows);
const IdxType off = indptr[rid]; const IdxType off = indptr[rid];
const IdxType len = indptr[rid + 1] - off; const IdxType len = indptr[rid + 1] - off;
if (len == 0)
continue;
if (len <= num_picks && !replace) { if (len <= num_picks && !replace) {
// nnz <= num_picks and w/o replacement, take all nnz // nnz <= num_picks and w/o replacement, take all nnz
for (int64_t j = 0; j < len; ++j) { for (int64_t j = 0; j < len; ++j) {
......
...@@ -452,6 +452,14 @@ def test_sample_neighbors_topk_outedge(): ...@@ -452,6 +452,14 @@ def test_sample_neighbors_topk_outedge():
_test_sample_neighbors_topk_outedge(False) _test_sample_neighbors_topk_outedge(False)
_test_sample_neighbors_topk_outedge(True) _test_sample_neighbors_topk_outedge(True)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_sample_neighbors_with_0deg():
g = dgl.graph([], num_nodes=5)
dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='in', replace=False)
dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='in', replace=True)
dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='out', replace=False)
dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='out', replace=True)
if __name__ == '__main__': if __name__ == '__main__':
test_random_walk() test_random_walk()
test_pack_traces() test_pack_traces()
...@@ -460,3 +468,4 @@ if __name__ == '__main__': ...@@ -460,3 +468,4 @@ if __name__ == '__main__':
test_sample_neighbors_outedge() test_sample_neighbors_outedge()
test_sample_neighbors_topk() test_sample_neighbors_topk()
test_sample_neighbors_topk_outedge() test_sample_neighbors_topk_outedge()
test_sample_neighbors_with_0deg()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment