Unverified Commit 2630d2eb authored by nv-dlasalle's avatar nv-dlasalle Committed by GitHub
Browse files

[Performance][bugfix] Implement `is_unibipartite` in C++ with caching. (#4556)

* updates

* Enable caching C++ result

* Add missing docstring

* Remove unused function

* Add unit test

* Address comments
parent cded5b80
...@@ -149,6 +149,37 @@ class GraphInterface : public runtime::Object { ...@@ -149,6 +149,37 @@ class GraphInterface : public runtime::Object {
*/ */
virtual bool IsMultigraph() const = 0; virtual bool IsMultigraph() const = 0;
/*!
* \return whether the graph is unibipartite
*/
virtual bool IsUniBipartite() const {
EdgeArray edges = Edges();
IdArray src = edges.src;
IdArray dst = edges.dst;
bool is_unibipartite = true;
const size_t n = edges.src.NumElements();
ATEN_ID_TYPE_SWITCH(src->dtype, IdType, {
auto src_v = src.ToVector<IdType>();
std::sort(src_v.begin(), src_v.end());
auto dst_v = dst.ToVector<IdType>();
std::sort(dst_v.begin(), dst_v.end());
// std::set_intersection() requires output, so this is better
for (size_t i = 0, j = 0; i < n && j < n;) {
if (src_v[i] < dst_v[j]) {
++i;
} else if (src_v[i] == dst_v[j]) {
is_unibipartite = false;
break;
} else {
++j;
}
}
});
return is_unibipartite;
}
/*! /*!
* \return whether the graph is read-only * \return whether the graph is read-only
*/ */
......
...@@ -601,6 +601,20 @@ class ImmutableGraph: public GraphInterface { ...@@ -601,6 +601,20 @@ class ImmutableGraph: public GraphInterface {
return true; return true;
} }
/**
* \brief Check if the graph is unibipartite.
*
* @return True if the graph is unibipartite.
*/
bool IsUniBipartite() const override {
if (!is_unibipartite_set_) {
is_unibipartite_ = GraphInterface::IsUniBipartite();
is_unibipartite_set_ = true;
}
return is_unibipartite_;
}
/*! \return the number of vertices in the graph.*/ /*! \return the number of vertices in the graph.*/
uint64_t NumVertices() const override { uint64_t NumVertices() const override {
return AnyGraph()->NumVertices(); return AnyGraph()->NumVertices();
...@@ -1000,6 +1014,12 @@ class ImmutableGraph: public GraphInterface { ...@@ -1000,6 +1014,12 @@ class ImmutableGraph: public GraphInterface {
std::string shared_mem_name_; std::string shared_mem_name_;
// We serialize the metadata of the graph index here for shared memory. // We serialize the metadata of the graph index here for shared memory.
NDArray serialized_shared_meta_; NDArray serialized_shared_meta_;
// Whether or not the `is_unibipartite_` property has been set.
mutable bool is_unibipartite_set_ = false;
// Whether this graph is unibipartite. If `is_unibipartite_set_` is false,
// then this flag should be considered in an unititialized state.
mutable bool is_unibipartite_ = false;
}; };
// inline implementations // inline implementations
......
...@@ -97,7 +97,7 @@ class DGLHeteroGraph(object): ...@@ -97,7 +97,7 @@ class DGLHeteroGraph(object):
errmsg = 'Invalid input. Expect a pair (srctypes, dsttypes) but got {}'.format( errmsg = 'Invalid input. Expect a pair (srctypes, dsttypes) but got {}'.format(
ntypes) ntypes)
raise TypeError(errmsg) raise TypeError(errmsg)
if not is_unibipartite(self._graph.metagraph): if not self._graph.is_metagraph_unibipartite():
raise ValueError('Invalid input. The metagraph must be a uni-directional' raise ValueError('Invalid input. The metagraph must be a uni-directional'
' bipartite graph.') ' bipartite graph.')
self._ntypes = ntypes[0] + ntypes[1] self._ntypes = ntypes[0] + ntypes[1]
...@@ -6204,23 +6204,6 @@ def make_canonical_etypes(etypes, ntypes, metagraph): ...@@ -6204,23 +6204,6 @@ def make_canonical_etypes(etypes, ntypes, metagraph):
rst = [(ntypes[sid], etypes[eid], ntypes[did]) for sid, did, eid in zip(src, dst, eid)] rst = [(ntypes[sid], etypes[eid], ntypes[did]) for sid, did, eid in zip(src, dst, eid)]
return rst return rst
def is_unibipartite(graph):
"""Internal function that returns whether the given graph is a uni-directional
bipartite graph.
Parameters
----------
graph : GraphIndex
Input graph
Returns
-------
bool
True if the graph is a uni-bipartite.
"""
src, dst, _ = graph.edges()
return set(src.tonumpy()).isdisjoint(set(dst.tonumpy()))
def find_src_dst_ntypes(ntypes, metagraph): def find_src_dst_ntypes(ntypes, metagraph):
"""Internal function to split ntypes into SRC and DST categories. """Internal function to split ntypes into SRC and DST categories.
......
...@@ -72,6 +72,10 @@ class HeteroGraphIndex(ObjectBase): ...@@ -72,6 +72,10 @@ class HeteroGraphIndex(ObjectBase):
""" """
return _CAPI_DGLHeteroGetMetaGraph(self) return _CAPI_DGLHeteroGetMetaGraph(self)
def is_metagraph_unibipartite(self):
"""Return whether or not the graph is unibiparite."""
return _CAPI_DGLHeteroIsMetaGraphUniBipartite(self)
def number_of_ntypes(self): def number_of_ntypes(self):
"""Return number of node types.""" """Return number of node types."""
return self.metagraph.number_of_nodes() return self.metagraph.number_of_nodes()
......
...@@ -103,7 +103,14 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraphWithNumNo ...@@ -103,7 +103,14 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraphWithNumNo
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetMetaGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetMetaGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
*rv = GraphRef(hg->meta_graph()); *rv = hg->meta_graph();
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsMetaGraphUniBipartite")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
GraphPtr mg = hg->meta_graph();
*rv = mg->IsUniBipartite();
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
......
...@@ -1884,6 +1884,35 @@ def test_ismultigraph(idtype): ...@@ -1884,6 +1884,35 @@ def test_ismultigraph(idtype):
{'A': 6, 'C': 6}, idtype=idtype, device=F.ctx()) {'A': 6, 'C': 6}, idtype=idtype, device=F.ctx())
assert g.is_multigraph == True assert g.is_multigraph == True
@parametrize_idtype
def test_graph_index_is_unibipartite(idtype):
g1 = dgl.heterograph({('A', 'AB', 'B'): ([0, 0, 1], [1, 2, 5])},
idtype=idtype, device=F.ctx())
assert g1._graph.is_metagraph_unibipartite()
# more complicated bipartite
g2 = dgl.heterograph({
('A', 'AB', 'B'): ([0, 0, 1], [1, 2, 5]),
('A', 'AC', 'C'): ([1, 0], [0, 0])
}, idtype=idtype, device=F.ctx())
assert g2._graph.is_metagraph_unibipartite()
g3 = dgl.heterograph({
('A', 'AB', 'B'): ([0, 0, 1], [1, 2, 5]),
('A', 'AC', 'C'): ([1, 0], [0, 0]),
('A', 'AA', 'A'): ([0, 1], [0, 1])
}, idtype=idtype, device=F.ctx())
assert not g3._graph.is_metagraph_unibipartite()
g4 = dgl.heterograph({
('A', 'AB', 'B'): ([0, 0, 1], [1, 2, 5]),
('C', 'CA', 'A'): ([1, 0], [0, 0])
}, idtype=idtype, device=F.ctx())
assert not g4._graph.is_metagraph_unibipartite()
@parametrize_idtype @parametrize_idtype
def test_bipartite(idtype): def test_bipartite(idtype):
g1 = dgl.heterograph({('A', 'AB', 'B'): ([0, 0, 1], [1, 2, 5])}, g1 = dgl.heterograph({('A', 'AB', 'B'): ([0, 0, 1], [1, 2, 5])},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment