Unverified Commit 7b972981 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix bug in flatten and is_unibipartite (#2279)

parent cbbbbde7
......@@ -656,25 +656,30 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsImmutableGraph")
DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef metagraph = args[0];
std::set<int64_t> dst_set;
std::set<int64_t> src_set;
std::unordered_set<int64_t> dst_set;
std::unordered_set<int64_t> src_set;
for (int64_t eid = 0; eid < metagraph->NumEdges(); ++eid) {
auto edge = metagraph->FindEdge(eid);
auto src = edge.first;
auto dst = edge.second;
dst_set.insert(dst);
if (dst_set.count(src))
return;
src_set.insert(src);
}
List<Value> srclist, dstlist;
List<List<Value>> ret_list;
for (auto dst : dst_set)
dstlist.push_back(Value(MakeValue(dst)));
for (int64_t nid = 0 ; nid < metagraph->NumVertices(); ++nid)
if (!dst_set.count(nid))
for (int64_t nid = 0; nid < metagraph->NumVertices(); ++nid) {
auto is_dst = dst_set.count(nid);
auto is_src = src_set.count(nid);
if (is_dst && is_src)
return;
else if (is_dst)
dstlist.push_back(Value(MakeValue(nid)));
else
// If a node type is isolated, put it in srctype as defined in the Python docstring.
srclist.push_back(Value(MakeValue(nid)));
}
ret_list.push_back(srclist);
ret_list.push_back(dstlist);
*rv = ret_list;
......
......@@ -1736,6 +1736,13 @@ def test_bipartite(idtype):
}, idtype=idtype, device=F.ctx())
assert not g3.is_unibipartite
g4 = dgl.heterograph({
('A', 'AB', 'B'): ([0, 0, 1], [1, 2, 5]),
('C', 'CA', 'A'): ([1, 0], [0, 0])
}, idtype=idtype, device=F.ctx())
assert not g4.is_unibipartite
@parametrize_dtype
def test_dtype_cast(idtype):
g = dgl.graph(([0, 1, 0, 2], [0, 1, 1, 0]), idtype=idtype, device=F.ctx())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment