Unverified Commit dcf16992 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[BugFix] check whether etype sorted when sampling (#4198)

parent a9768cb3
...@@ -277,7 +277,10 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes, ...@@ -277,7 +277,10 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t et_offset = 0; int64_t et_offset = 0;
int64_t et_len = 1; int64_t et_len = 1;
for (int64_t j = 0; j < len; ++j) { for (int64_t j = 0; j < len; ++j) {
if ((j+1 == len) || cur_et != et[et_idx[j+1]]) { CHECK((j + 1 == len) || (et[et_idx[j]] <= et[et_idx[j + 1]]))
<< "Edge type is not sorted. Please sort in advance or specify "
"'etype_sorted' as false.";
if ((j + 1 == len) || cur_et != et[et_idx[j + 1]]) {
// 1 end of the current etype // 1 end of the current etype
// 2 end of the row // 2 end of the row
// random pick for current etype // random pick for current etype
......
...@@ -858,6 +858,32 @@ def test_sample_neighbors_etype_homogeneous(format_, direction, replace): ...@@ -858,6 +858,32 @@ def test_sample_neighbors_etype_homogeneous(format_, direction, replace):
h_g, seeds, dgl.ETYPE, fanouts, replace=replace, edge_dir=direction) h_g, seeds, dgl.ETYPE, fanouts, replace=replace, edge_dir=direction)
check_num(h_g, all_src, all_dst, subg, replace, fanouts, direction) check_num(h_g, all_src, all_dst, subg, replace, fanouts, direction)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
@pytest.mark.parametrize('format_', ['csr', 'csc'])
@pytest.mark.parametrize('direction', ['in', 'out'])
def test_sample_neighbors_etype_sorted_homogeneous(format_, direction):
rare_cnt = 4
g = create_etype_test_graph(100, 30, rare_cnt)
h_g = dgl.to_homogeneous(g)
seed_ntype = g.get_ntype_id("u")
seeds = F.nonzero_1d(h_g.ndata[dgl.NTYPE] == seed_ntype)
fanouts = F.tensor([6, 5, 4, 3, 2], dtype=F.int64)
h_g = h_g.formats(format_)
if (direction, format_) in [('in', 'csr'), ('out', 'csc')]:
h_g = h_g.formats(['csc', 'csr', 'coo'])
orig_etype = F.asnumpy(h_g.edata[dgl.ETYPE])
h_g.edata[dgl.ETYPE] = F.tensor(
np.sort(orig_etype)[::-1].tolist(), dtype=F.int64)
try:
dgl.sampling.sample_etype_neighbors(
h_g, seeds, dgl.ETYPE, fanouts, edge_dir=direction, etype_sorted=True)
fail = False
except dgl.DGLError:
fail = True
assert fail
@pytest.mark.parametrize('dtype', ['int32', 'int64']) @pytest.mark.parametrize('dtype', ['int32', 'int64'])
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented") @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_sample_neighbors_exclude_edges_heteroG(dtype): def test_sample_neighbors_exclude_edges_heteroG(dtype):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment