Unverified Commit be936da8 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

use FFI for subgraph. (#781)

parent 0f127637
...@@ -354,7 +354,7 @@ class GraphInterface : public runtime::Object { ...@@ -354,7 +354,7 @@ class GraphInterface : public runtime::Object {
DGL_DEFINE_OBJECT_REF(GraphRef, GraphInterface); DGL_DEFINE_OBJECT_REF(GraphRef, GraphInterface);
/*! \brief Subgraph data structure */ /*! \brief Subgraph data structure */
struct Subgraph { struct Subgraph : public runtime::Object {
/*! \brief The graph. */ /*! \brief The graph. */
GraphPtr graph; GraphPtr graph;
/*! /*!
...@@ -367,8 +367,14 @@ struct Subgraph { ...@@ -367,8 +367,14 @@ struct Subgraph {
* \note This is also a map from the new edge id to the edge id in the parent graph. * \note This is also a map from the new edge id to the edge id in the parent graph.
*/ */
IdArray induced_edges; IdArray induced_edges;
static constexpr const char* _type_key = "graph.Subgraph";
DGL_DECLARE_OBJECT_TYPE_INFO(Subgraph, runtime::Object);
}; };
// Define SubgraphRef
DGL_DEFINE_OBJECT_REF(SubgraphRef, Subgraph);
} // namespace dgl } // namespace dgl
#endif // DGL_GRAPH_INTERFACE_H_ #endif // DGL_GRAPH_INTERFACE_H_
...@@ -530,10 +530,7 @@ class GraphIndex(ObjectBase): ...@@ -530,10 +530,7 @@ class GraphIndex(ObjectBase):
The subgraph index. The subgraph index.
""" """
v_array = v.todgltensor() v_array = v.todgltensor()
rst = _CAPI_DGLGraphVertexSubgraph(self, v_array) return _CAPI_DGLGraphVertexSubgraph(self, v_array)
induced_edges = utils.toindex(rst(2))
gidx = rst(0)
return SubgraphIndex(gidx, self, v, induced_edges)
def node_subgraphs(self, vs_arr): def node_subgraphs(self, vs_arr):
"""Return the induced node subgraphs. """Return the induced node subgraphs.
...@@ -571,10 +568,7 @@ class GraphIndex(ObjectBase): ...@@ -571,10 +568,7 @@ class GraphIndex(ObjectBase):
The subgraph index. The subgraph index.
""" """
e_array = e.todgltensor() e_array = e.todgltensor()
rst = _CAPI_DGLGraphEdgeSubgraph(self, e_array, preserve_nodes) return _CAPI_DGLGraphEdgeSubgraph(self, e_array, preserve_nodes)
induced_nodes = utils.toindex(rst(1))
gidx = rst(0)
return SubgraphIndex(gidx, self, induced_nodes, e)
@utils.cached_member(cache='_cache', prefix='scipy_adj') @utils.cached_member(cache='_cache', prefix='scipy_adj')
def adjacency_matrix_scipy(self, transpose, fmt, return_edge_ids=None): def adjacency_matrix_scipy(self, transpose, fmt, return_edge_ids=None):
...@@ -917,33 +911,45 @@ class GraphIndex(ObjectBase): ...@@ -917,33 +911,45 @@ class GraphIndex(ObjectBase):
""" """
return _CAPI_DGLImmutableGraphAsNumBits(self, int(bits)) return _CAPI_DGLImmutableGraphAsNumBits(self, int(bits))
class SubgraphIndex(object): @register_object('graph.Subgraph')
"""Internal subgraph data structure. class SubgraphIndex(ObjectBase):
"""Subgraph data structure"""
@property
def graph(self):
"""The subgraph structure
Parameters Returns
---------- -------
graph : GraphIndex GraphIndex
The graph structure of this subgraph. The subgraph
parent : GraphIndex """
The parent graph index. return _CAPI_DGLSubgraphGetGraph(self)
induced_nodes : utils.Index
The parent node ids in this subgraph.
induced_edges : utils.Index
The parent edge ids in this subgraph.
"""
def __init__(self, graph, parent, induced_nodes, induced_edges):
self.graph = graph
self.parent = parent
self.induced_nodes = induced_nodes
self.induced_edges = induced_edges
def __getstate__(self): @property
raise NotImplementedError( def induced_nodes(self):
"SubgraphIndex pickling is not supported yet.") """Induced nodes for each node type. The return list
length should be equal to the number of node types.
def __setstate__(self, state): Returns
raise NotImplementedError( -------
"SubgraphIndex unpickling is not supported yet.") list of utils.Index
Induced nodes
"""
ret = _CAPI_DGLSubgraphGetInducedVertices(self)
return utils.toindex(ret)
@property
def induced_edges(self):
"""Induced edges for each edge type. The return list
length should be equal to the number of edge types.
Returns
-------
list of utils.Index
Induced edges
"""
ret = _CAPI_DGLSubgraphGetInducedEdges(self)
return utils.toindex(ret)
############################################################### ###############################################################
......
...@@ -19,27 +19,6 @@ using dgl::runtime::NDArray; ...@@ -19,27 +19,6 @@ using dgl::runtime::NDArray;
namespace dgl { namespace dgl {
namespace {
// Convert Subgraph structure to PackedFunc.
PackedFunc ConvertSubgraphToPackedFunc(const Subgraph& sg) {
auto body = [sg] (DGLArgs args, DGLRetValue* rv) {
const int which = args[0];
if (which == 0) {
*rv = GraphRef(sg.graph);
} else if (which == 1) {
*rv = std::move(sg.induced_vertices);
} else if (which == 2) {
*rv = std::move(sg.induced_edges);
} else {
LOG(FATAL) << "invalid choice";
}
};
return PackedFunc(body);
}
} // namespace
///////////////////////////// Graph API /////////////////////////////////// ///////////////////////////// Graph API ///////////////////////////////////
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreateMutable") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreateMutable")
...@@ -312,7 +291,8 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph") ...@@ -312,7 +291,8 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray vids = args[1]; const IdArray vids = args[1];
*rv = ConvertSubgraphToPackedFunc(g->VertexSubgraph(vids)); std::shared_ptr<Subgraph> subg(new Subgraph(g->VertexSubgraph(vids)));
*rv = SubgraphRef(subg);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
...@@ -320,7 +300,9 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph") ...@@ -320,7 +300,9 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray eids = args[1]; const IdArray eids = args[1];
bool preserve_nodes = args[2]; bool preserve_nodes = args[2];
*rv = ConvertSubgraphToPackedFunc(g->EdgeSubgraph(eids, preserve_nodes)); std::shared_ptr<Subgraph> subg(
new Subgraph(g->EdgeSubgraph(eids, preserve_nodes)));
*rv = SubgraphRef(subg);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphGetAdj") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphGetAdj")
...@@ -344,4 +326,24 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits") ...@@ -344,4 +326,24 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits")
*rv = g->NumBits(); *rv = g->NumBits();
}); });
// Subgraph C APIs
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SubgraphRef subg = args[0];
*rv = GraphRef(subg->graph);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SubgraphRef subg = args[0];
*rv = subg->induced_vertices;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SubgraphRef subg = args[0];
*rv = subg->induced_edges;
});
} // namespace dgl } // namespace dgl
...@@ -196,7 +196,11 @@ Subgraph CSR::VertexSubgraph(IdArray vids) const { ...@@ -196,7 +196,11 @@ Subgraph CSR::VertexSubgraph(IdArray vids) const {
const auto& submat = aten::CSRSliceMatrix(adj_, vids, vids); const auto& submat = aten::CSRSliceMatrix(adj_, vids, vids);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context()); IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
CSRPtr subcsr(new CSR(submat.indptr, submat.indices, sub_eids)); CSRPtr subcsr(new CSR(submat.indptr, submat.indices, sub_eids));
return Subgraph{subcsr, vids, submat.data}; Subgraph subg;
subg.graph = subcsr;
subg.induced_vertices = vids;
subg.induced_edges = submat.data;
return subg;
} }
CSRPtr CSR::Transpose() const { CSRPtr CSR::Transpose() const {
...@@ -313,20 +317,25 @@ EdgeArray COO::Edges(const std::string &order) const { ...@@ -313,20 +317,25 @@ EdgeArray COO::Edges(const std::string &order) const {
Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const { Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array."; CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";
COOPtr subcoo;
IdArray induced_nodes;
if (!preserve_nodes) { if (!preserve_nodes) {
IdArray new_src = aten::IndexSelect(adj_.row, eids); IdArray new_src = aten::IndexSelect(adj_.row, eids);
IdArray new_dst = aten::IndexSelect(adj_.col, eids); IdArray new_dst = aten::IndexSelect(adj_.col, eids);
IdArray induced_nodes = aten::Relabel_({new_src, new_dst}); induced_nodes = aten::Relabel_({new_src, new_dst});
const auto new_nnodes = induced_nodes->shape[0]; const auto new_nnodes = induced_nodes->shape[0];
COOPtr subcoo(new COO(new_nnodes, new_src, new_dst)); subcoo = COOPtr(new COO(new_nnodes, new_src, new_dst));
return Subgraph{subcoo, induced_nodes, eids};
} else { } else {
IdArray new_src = aten::IndexSelect(adj_.row, eids); IdArray new_src = aten::IndexSelect(adj_.row, eids);
IdArray new_dst = aten::IndexSelect(adj_.col, eids); IdArray new_dst = aten::IndexSelect(adj_.col, eids);
IdArray induced_nodes = aten::Range(0, NumVertices(), NumBits(), Context()); induced_nodes = aten::Range(0, NumVertices(), NumBits(), Context());
COOPtr subcoo(new COO(NumVertices(), new_src, new_dst)); subcoo = COOPtr(new COO(NumVertices(), new_src, new_dst));
return Subgraph{subcoo, induced_nodes, eids};
} }
Subgraph subg;
subg.graph = subcoo;
subg.induced_vertices = induced_nodes;
subg.induced_edges = eids;
return subg;
} }
CSRPtr COO::ToCSR() const { CSRPtr COO::ToCSR() const {
...@@ -444,15 +453,15 @@ Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const { ...@@ -444,15 +453,15 @@ Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
// We prefer to generate a subgraph from out-csr. // We prefer to generate a subgraph from out-csr.
auto sg = GetOutCSR()->VertexSubgraph(vids); auto sg = GetOutCSR()->VertexSubgraph(vids);
CSRPtr subcsr = std::dynamic_pointer_cast<CSR>(sg.graph); CSRPtr subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
return Subgraph{GraphPtr(new ImmutableGraph(subcsr)), sg.graph = GraphPtr(new ImmutableGraph(subcsr));
sg.induced_vertices, sg.induced_edges}; return sg;
} }
Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const { Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
auto sg = GetCOO()->EdgeSubgraph(eids, preserve_nodes); auto sg = GetCOO()->EdgeSubgraph(eids, preserve_nodes);
COOPtr subcoo = std::dynamic_pointer_cast<COO>(sg.graph); COOPtr subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
return Subgraph{GraphPtr(new ImmutableGraph(subcoo)), sg.graph = GraphPtr(new ImmutableGraph(subcoo));
sg.induced_vertices, sg.induced_edges}; return sg;
} }
std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &fmt) const { std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &fmt) const {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment