"src/rpc/network/socket_communicator.h" did not exist on "5d494c620757c2f4c1c70e735da892f11bef32c9"
Unverified Commit 7b972981 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix bug in flatten and is_unibipartite (#2279)

parent cbbbbde7
...@@ -656,25 +656,30 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsImmutableGraph") ...@@ -656,25 +656,30 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsImmutableGraph")
DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes") DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef metagraph = args[0]; GraphRef metagraph = args[0];
std::set<int64_t> dst_set; std::unordered_set<int64_t> dst_set;
std::set<int64_t> src_set; std::unordered_set<int64_t> src_set;
for (int64_t eid = 0; eid < metagraph->NumEdges(); ++eid) { for (int64_t eid = 0; eid < metagraph->NumEdges(); ++eid) {
auto edge = metagraph->FindEdge(eid); auto edge = metagraph->FindEdge(eid);
auto src = edge.first; auto src = edge.first;
auto dst = edge.second; auto dst = edge.second;
dst_set.insert(dst); dst_set.insert(dst);
if (dst_set.count(src)) src_set.insert(src);
return;
} }
List<Value> srclist, dstlist; List<Value> srclist, dstlist;
List<List<Value>> ret_list; List<List<Value>> ret_list;
for (auto dst : dst_set) for (int64_t nid = 0; nid < metagraph->NumVertices(); ++nid) {
dstlist.push_back(Value(MakeValue(dst))); auto is_dst = dst_set.count(nid);
for (int64_t nid = 0 ; nid < metagraph->NumVertices(); ++nid) auto is_src = src_set.count(nid);
if (!dst_set.count(nid)) if (is_dst && is_src)
return;
else if (is_dst)
dstlist.push_back(Value(MakeValue(nid)));
else
// If a node type is isolated, put it in srctype as defined in the Python docstring.
srclist.push_back(Value(MakeValue(nid))); srclist.push_back(Value(MakeValue(nid)));
}
ret_list.push_back(srclist); ret_list.push_back(srclist);
ret_list.push_back(dstlist); ret_list.push_back(dstlist);
*rv = ret_list; *rv = ret_list;
......
...@@ -1736,6 +1736,13 @@ def test_bipartite(idtype): ...@@ -1736,6 +1736,13 @@ def test_bipartite(idtype):
}, idtype=idtype, device=F.ctx()) }, idtype=idtype, device=F.ctx())
assert not g3.is_unibipartite assert not g3.is_unibipartite
g4 = dgl.heterograph({
('A', 'AB', 'B'): ([0, 0, 1], [1, 2, 5]),
('C', 'CA', 'A'): ([1, 0], [0, 0])
}, idtype=idtype, device=F.ctx())
assert not g4.is_unibipartite
@parametrize_dtype @parametrize_dtype
def test_dtype_cast(idtype): def test_dtype_cast(idtype):
g = dgl.graph(([0, 1, 0, 2], [0, 1, 1, 0]), idtype=idtype, device=F.ctx()) g = dgl.graph(([0, 1, 0, 2], [0, 1, 1, 0]), idtype=idtype, device=F.ctx())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment