Unverified Commit a5d21c2b authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Sampling] handle fanout=-1 differently from fanout>0 in sample_etype_neighbors() (#4716)

parent e452179c
...@@ -284,7 +284,8 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes, ...@@ -284,7 +284,8 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
// 1 end of the current etype // 1 end of the current etype
// 2 end of the row // 2 end of the row
// random pick for current etype // random pick for current etype
if (et_len <= num_picks[cur_et] && !replace) { if ((num_picks[cur_et] == -1) ||
(et_len <= num_picks[cur_et] && !replace)) {
// fast path, select all // fast path, select all
for (int64_t k = 0; k < et_len; ++k) { for (int64_t k = 0; k < et_len; ++k) {
rows.push_back(rid); rows.push_back(rid);
......
...@@ -882,7 +882,7 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction): ...@@ -882,7 +882,7 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction):
h_g = dgl.to_homogeneous(g) h_g = dgl.to_homogeneous(g)
seed_ntype = g.get_ntype_id("u") seed_ntype = g.get_ntype_id("u")
seeds = F.nonzero_1d(h_g.ndata[dgl.NTYPE] == seed_ntype) seeds = F.nonzero_1d(h_g.ndata[dgl.NTYPE] == seed_ntype)
fanouts = F.tensor([6, 5, 4, 3, 2], dtype=F.int64) fanouts = F.tensor([6, 5, -1, 3, 2], dtype=F.int64)
h_g = h_g.formats(format_) h_g = h_g.formats(format_)
if (direction, format_) in [('in', 'csr'), ('out', 'csc')]: if (direction, format_) in [('in', 'csr'), ('out', 'csc')]:
h_g = h_g.formats(['csc', 'csr', 'coo']) h_g = h_g.formats(['csc', 'csr', 'coo'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment