Unverified Commit a5d21c2b authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Sampling] handle fanout=-1 differently from fanout>0 in sample_etype_neighbors() (#4716)

parent e452179c
......@@ -284,7 +284,8 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
// 1 end of the current etype
// 2 end of the row
// random pick for current etype
if (et_len <= num_picks[cur_et] && !replace) {
if ((num_picks[cur_et] == -1) ||
(et_len <= num_picks[cur_et] && !replace)) {
// fast path, select all
for (int64_t k = 0; k < et_len; ++k) {
rows.push_back(rid);
......
......@@ -882,7 +882,7 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction):
h_g = dgl.to_homogeneous(g)
seed_ntype = g.get_ntype_id("u")
seeds = F.nonzero_1d(h_g.ndata[dgl.NTYPE] == seed_ntype)
fanouts = F.tensor([6, 5, 4, 3, 2], dtype=F.int64)
fanouts = F.tensor([6, 5, -1, 3, 2], dtype=F.int64)
h_g = h_g.formats(format_)
if (direction, format_) in [('in', 'csr'), ('out', 'csc')]:
h_g = h_g.formats(['csc', 'csr', 'coo'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment