"examples/vscode:/vscode.git/clone" did not exist on "f99725adbc9cef2fd6cc6eef18a86e3d7c1e5339"
Unverified Commit be936da8 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

use FFI for subgraph. (#781)

parent 0f127637
......@@ -354,7 +354,7 @@ class GraphInterface : public runtime::Object {
DGL_DEFINE_OBJECT_REF(GraphRef, GraphInterface);
/*! \brief Subgraph data structure */
struct Subgraph {
struct Subgraph : public runtime::Object {
/*! \brief The graph. */
GraphPtr graph;
/*!
......@@ -367,8 +367,14 @@ struct Subgraph {
* \note This is also a map from the new edge id to the edge id in the parent graph.
*/
IdArray induced_edges;
static constexpr const char* _type_key = "graph.Subgraph";
DGL_DECLARE_OBJECT_TYPE_INFO(Subgraph, runtime::Object);
};
// Define SubgraphRef
DGL_DEFINE_OBJECT_REF(SubgraphRef, Subgraph);
} // namespace dgl
#endif // DGL_GRAPH_INTERFACE_H_
......@@ -530,10 +530,7 @@ class GraphIndex(ObjectBase):
The subgraph index.
"""
v_array = v.todgltensor()
rst = _CAPI_DGLGraphVertexSubgraph(self, v_array)
induced_edges = utils.toindex(rst(2))
gidx = rst(0)
return SubgraphIndex(gidx, self, v, induced_edges)
return _CAPI_DGLGraphVertexSubgraph(self, v_array)
def node_subgraphs(self, vs_arr):
"""Return the induced node subgraphs.
......@@ -571,10 +568,7 @@ class GraphIndex(ObjectBase):
The subgraph index.
"""
e_array = e.todgltensor()
rst = _CAPI_DGLGraphEdgeSubgraph(self, e_array, preserve_nodes)
induced_nodes = utils.toindex(rst(1))
gidx = rst(0)
return SubgraphIndex(gidx, self, induced_nodes, e)
return _CAPI_DGLGraphEdgeSubgraph(self, e_array, preserve_nodes)
@utils.cached_member(cache='_cache', prefix='scipy_adj')
def adjacency_matrix_scipy(self, transpose, fmt, return_edge_ids=None):
......@@ -917,33 +911,45 @@ class GraphIndex(ObjectBase):
"""
return _CAPI_DGLImmutableGraphAsNumBits(self, int(bits))
class SubgraphIndex(object):
"""Internal subgraph data structure.
@register_object('graph.Subgraph')
class SubgraphIndex(ObjectBase):
"""Subgraph data structure"""
@property
def graph(self):
"""The subgraph structure
Parameters
----------
graph : GraphIndex
The graph structure of this subgraph.
parent : GraphIndex
The parent graph index.
induced_nodes : utils.Index
The parent node ids in this subgraph.
induced_edges : utils.Index
The parent edge ids in this subgraph.
"""
def __init__(self, graph, parent, induced_nodes, induced_edges):
self.graph = graph
self.parent = parent
self.induced_nodes = induced_nodes
self.induced_edges = induced_edges
Returns
-------
GraphIndex
The subgraph
"""
return _CAPI_DGLSubgraphGetGraph(self)
def __getstate__(self):
raise NotImplementedError(
"SubgraphIndex pickling is not supported yet.")
@property
def induced_nodes(self):
"""Induced nodes for each node type. The return list
length should be equal to the number of node types.
def __setstate__(self, state):
raise NotImplementedError(
"SubgraphIndex unpickling is not supported yet.")
Returns
-------
list of utils.Index
Induced nodes
"""
ret = _CAPI_DGLSubgraphGetInducedVertices(self)
return utils.toindex(ret)
@property
def induced_edges(self):
"""Induced edges for each edge type. The return list
length should be equal to the number of edge types.
Returns
-------
list of utils.Index
Induced edges
"""
ret = _CAPI_DGLSubgraphGetInducedEdges(self)
return utils.toindex(ret)
###############################################################
......
......@@ -19,27 +19,6 @@ using dgl::runtime::NDArray;
namespace dgl {
namespace {
// Convert Subgraph structure to PackedFunc.
PackedFunc ConvertSubgraphToPackedFunc(const Subgraph& sg) {
auto body = [sg] (DGLArgs args, DGLRetValue* rv) {
const int which = args[0];
if (which == 0) {
*rv = GraphRef(sg.graph);
} else if (which == 1) {
*rv = std::move(sg.induced_vertices);
} else if (which == 2) {
*rv = std::move(sg.induced_edges);
} else {
LOG(FATAL) << "invalid choice";
}
};
return PackedFunc(body);
}
} // namespace
///////////////////////////// Graph API ///////////////////////////////////
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreateMutable")
......@@ -312,7 +291,8 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray vids = args[1];
*rv = ConvertSubgraphToPackedFunc(g->VertexSubgraph(vids));
std::shared_ptr<Subgraph> subg(new Subgraph(g->VertexSubgraph(vids)));
*rv = SubgraphRef(subg);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
......@@ -320,7 +300,9 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
GraphRef g = args[0];
const IdArray eids = args[1];
bool preserve_nodes = args[2];
*rv = ConvertSubgraphToPackedFunc(g->EdgeSubgraph(eids, preserve_nodes));
std::shared_ptr<Subgraph> subg(
new Subgraph(g->EdgeSubgraph(eids, preserve_nodes)));
*rv = SubgraphRef(subg);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphGetAdj")
......@@ -344,4 +326,24 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits")
*rv = g->NumBits();
});
// Subgraph C APIs
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SubgraphRef subg = args[0];
*rv = GraphRef(subg->graph);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SubgraphRef subg = args[0];
*rv = subg->induced_vertices;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SubgraphRef subg = args[0];
*rv = subg->induced_edges;
});
} // namespace dgl
......@@ -196,7 +196,11 @@ Subgraph CSR::VertexSubgraph(IdArray vids) const {
const auto& submat = aten::CSRSliceMatrix(adj_, vids, vids);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
CSRPtr subcsr(new CSR(submat.indptr, submat.indices, sub_eids));
return Subgraph{subcsr, vids, submat.data};
Subgraph subg;
subg.graph = subcsr;
subg.induced_vertices = vids;
subg.induced_edges = submat.data;
return subg;
}
CSRPtr CSR::Transpose() const {
......@@ -313,20 +317,25 @@ EdgeArray COO::Edges(const std::string &order) const {
Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";
COOPtr subcoo;
IdArray induced_nodes;
if (!preserve_nodes) {
IdArray new_src = aten::IndexSelect(adj_.row, eids);
IdArray new_dst = aten::IndexSelect(adj_.col, eids);
IdArray induced_nodes = aten::Relabel_({new_src, new_dst});
induced_nodes = aten::Relabel_({new_src, new_dst});
const auto new_nnodes = induced_nodes->shape[0];
COOPtr subcoo(new COO(new_nnodes, new_src, new_dst));
return Subgraph{subcoo, induced_nodes, eids};
subcoo = COOPtr(new COO(new_nnodes, new_src, new_dst));
} else {
IdArray new_src = aten::IndexSelect(adj_.row, eids);
IdArray new_dst = aten::IndexSelect(adj_.col, eids);
IdArray induced_nodes = aten::Range(0, NumVertices(), NumBits(), Context());
COOPtr subcoo(new COO(NumVertices(), new_src, new_dst));
return Subgraph{subcoo, induced_nodes, eids};
induced_nodes = aten::Range(0, NumVertices(), NumBits(), Context());
subcoo = COOPtr(new COO(NumVertices(), new_src, new_dst));
}
Subgraph subg;
subg.graph = subcoo;
subg.induced_vertices = induced_nodes;
subg.induced_edges = eids;
return subg;
}
CSRPtr COO::ToCSR() const {
......@@ -444,15 +453,15 @@ Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
// We prefer to generate a subgraph from out-csr.
auto sg = GetOutCSR()->VertexSubgraph(vids);
CSRPtr subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
return Subgraph{GraphPtr(new ImmutableGraph(subcsr)),
sg.induced_vertices, sg.induced_edges};
sg.graph = GraphPtr(new ImmutableGraph(subcsr));
return sg;
}
Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
auto sg = GetCOO()->EdgeSubgraph(eids, preserve_nodes);
COOPtr subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
return Subgraph{GraphPtr(new ImmutableGraph(subcoo)),
sg.induced_vertices, sg.induced_edges};
sg.graph = GraphPtr(new ImmutableGraph(subcoo));
return sg;
}
std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &fmt) const {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment