Unverified Commit 97b08fbb authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[BUG] fix crash in neighbor sampling w/ replacement on 0 degree nodes (#1402)

parent e3a9a6bb
......@@ -90,6 +90,9 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
CHECK_LT(rid, mat.num_rows);
const IdxType off = indptr[rid];
const IdxType len = indptr[rid + 1] - off;
if (len == 0)
continue;
if (len <= num_picks && !replace) {
// nnz <= num_picks and w/o replacement, take all nnz
for (int64_t j = 0; j < len; ++j) {
......
......@@ -452,6 +452,14 @@ def test_sample_neighbors_topk_outedge():
_test_sample_neighbors_topk_outedge(False)
_test_sample_neighbors_topk_outedge(True)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_sample_neighbors_with_0deg():
g = dgl.graph([], num_nodes=5)
dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='in', replace=False)
dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='in', replace=True)
dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='out', replace=False)
dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='out', replace=True)
if __name__ == '__main__':
test_random_walk()
test_pack_traces()
......@@ -460,3 +468,4 @@ if __name__ == '__main__':
test_sample_neighbors_outedge()
test_sample_neighbors_topk()
test_sample_neighbors_topk_outedge()
test_sample_neighbors_with_0deg()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment