Commit dba79f1d authored by Minjie Wang's avatar Minjie Wang
Browse files

also return edge id for many APIs

parent 71e6e0cd
......@@ -33,6 +33,12 @@ class Graph;
*/
class Graph {
public:
/* \brief structure used to represent a list of edges */
typedef struct {
/* \brief the two endpoints and the id of the edge */
IdArray src, dst, id;
} EdgeArray;
/*! \brief default constructor */
Graph() {}
......@@ -129,7 +135,7 @@ class Graph {
IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const;
/*!
* \brief Get the id of the given edge.
* \brief Get the edge id using the two endpoints
* \note Edges are associated with an integer id start from zero.
* The id is assigned when the edge is being added to the graph.
* \param src The source vertex.
......@@ -139,7 +145,7 @@ class Graph {
dgl_id_t EdgeId(dgl_id_t src, dgl_id_t dst) const;
/*!
* \brief Get the id of the given edges.
* \brief Get the edge id using the two endpoints
* \note Edges are associated with an integer id start from zero.
* The id is assigned when the edge is being added to the graph.
* \return the edge id array.
......@@ -150,16 +156,16 @@ class Graph {
* \brief Get the in edges of the vertex.
* \note The returned dst id array is filled with vid.
* \param vid The vertex id.
* \return the id arrays of the two endpoints of the edges.
* \return the edges
*/
std::pair<IdArray, IdArray> InEdges(dgl_id_t vid) const;
EdgeArray InEdges(dgl_id_t vid) const;
/*!
* \brief Get the in edges of the vertices.
* \param vids The vertex id array.
* \return the id arrays of the two endpoints of the edges.
*/
std::pair<IdArray, IdArray> InEdges(IdArray vids) const;
EdgeArray InEdges(IdArray vids) const;
/*!
* \brief Get the out edges of the vertex.
......@@ -167,21 +173,22 @@ class Graph {
* \param vid The vertex id.
* \return the id arrays of the two endpoints of the edges.
*/
std::pair<IdArray, IdArray> OutEdges(dgl_id_t vid) const;
EdgeArray OutEdges(dgl_id_t vid) const;
/*!
* \brief Get the out edges of the vertices.
* \param vids The vertex id array.
* \return the id arrays of the two endpoints of the edges.
*/
std::pair<IdArray, IdArray> OutEdges(IdArray vids) const;
EdgeArray OutEdges(IdArray vids) const;
/*!
* \brief Get all the edges in the graph.
* \note If sorted is true, the id array is not returned.
* \param sorted Whether the returned edge list is sorted by their edge ids.
* \return the id arrays of the two endpoints of the edges.
*/
std::pair<IdArray, IdArray> Edges(bool sorted = false) const;
EdgeArray Edges(bool sorted = false) const;
/*!
* \brief Get the in degree of the given vertex.
......@@ -262,6 +269,10 @@ class Graph {
*/
Graph Reverse() const;
static Graph Merge(std::vector<Graph> graphs);
std::vector<Graph> Split(std::vector<IdArray> vids_array) const;
private:
/*! \brief Internal edge list type */
struct EdgeList {
......@@ -270,7 +281,9 @@ class Graph {
/*! \brief predecessor vertex list */
vector_view<dgl_id_t> pred;
/*! \brief (local) succ edge id property */
std::vector<dgl_id_t> edge_id;
std::vector<dgl_id_t> succ_edge_id;
/*! \brief (local) pred edge id property */
std::vector<dgl_id_t> pred_edge_id;
};
/*! \brief Adjacency list using vector storage */
// TODO(minjie): adjacent list is good for graph mutation and finding pred/succ.
......
......@@ -67,42 +67,51 @@ class DGLGraph(object):
v_array = F.asdglarray(v.totensor())
return _CAPI_DGLGraphEdgeIds(self._handle, u_array, v_array)
def inedges(self, v):
def in_edges(self, v):
if isinstance(v, int):
pair = _CAPI_DGLGraphInEdges_1(self._handle, v)
edge_array = _CAPI_DGLGraphInEdges_1(self._handle, v)
else:
v = utils.Index(v)
v_array = F.asdglarray(v.totensor())
pair = _CAPI_DGLGraphInEdges_2(self._handle, v_array)
return pair(0), pair(1)
edge_array = _CAPI_DGLGraphInEdges_2(self._handle, v_array)
src = edge_array(0)
dst = edge_array(1)
eid = edge_array(2)
return src, dst, eid
def outedges(self, v):
def out_edges(self, v):
if isinstance(v, int):
pair = _CAPI_DGLGraphOutEdges_1(self._handle, v)
edge_array = _CAPI_DGLGraphOutEdges_1(self._handle, v)
else:
v = utils.Index(v)
v_array = F.asdglarray(v.totensor())
pair = _CAPI_DGLGraphOutEdges_2(self._handle, v_array)
return pair(0), pair(1)
edge_array = _CAPI_DGLGraphOutEdges_2(self._handle, v_array)
src = edge_array(0)
dst = edge_array(1)
eid = edge_array(2)
return src, dst, eid
def indegree(self, v):
def edges(self, sorted=False):
edge_array = _CAPI_DGLGraphEdges(self._handle, sorted)
src = edge_array(0)
dst = edge_array(1)
eid = edge_array(2)
return src, dst, eid
def in_degree(self, v):
return _CAPI_DGLGraphInDegree(self._handle, v)
def indegrees(self, v):
def in_degrees(self, v):
v = utils.Index(v)
v_array = F.asdglarray(v.totensor())
return _CAPI_DGLGraphInDegrees(self._handle, v_array)
def outdegree(self, v):
def out_degree(self, v):
return _CAPI_DGLGraphOutDegree(self._handle, v)
def outdegrees(self, v):
def out_degrees(self, v):
v = utils.Index(v)
v_array = F.asdglarray(v.totensor())
return _CAPI_DGLGraphOutDegrees(self._handle, v_array)
def edges(self, sorted=False):
pair = _CAPI_DGLGraphEdges(self._handle, sorted)
return pair(0), pair(1)
_init_api("dgl.cgraph")
......@@ -21,8 +21,9 @@ void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
<< "In valid vertices: " << src << " " << dst;
dgl_id_t eid = num_edges_++;
adjlist_[src].succ.push_back(dst);
adjlist_[src].edge_id.push_back(eid);
adjlist_[src].succ_edge_id.push_back(eid);
adjlist_[dst].pred.push_back(src);
adjlist_[dst].pred_edge_id.push_back(eid);
}
void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
......@@ -137,7 +138,7 @@ dgl_id_t Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
const auto& succ = adjlist_[src].succ;
for (size_t i = 0; i < succ.size(); ++i) {
if (succ[i] == dst) {
return adjlist_[src].edge_id[i];
return adjlist_[src].succ_edge_id[i];
}
}
LOG(FATAL) << "invalid edge: " << src << " -> " << dst;
......@@ -176,17 +177,25 @@ IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
}
// O(E)
std::pair<IdArray, IdArray> Graph::InEdges(dgl_id_t vid) const {
const auto& src = Predecessors(vid);
const auto srclen = src->shape[0];
IdArray dst = IdArray::Empty({srclen}, src->dtype, src->ctx);
Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = adjlist_[vid].pred.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* src_data = static_cast<int64_t*>(src->data);
int64_t* dst_data = static_cast<int64_t*>(dst->data);
std::fill(dst_data, dst_data + srclen, vid);
return std::make_pair(src, dst);
int64_t* eid_data = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) {
src_data[i] = adjlist_[vid].pred[i];
eid_data[i] = adjlist_[vid].pred_edge_id[i];
}
std::fill(dst_data, dst_data + len, vid);
return EdgeArray{src, dst, eid};
}
// O(E)
std::pair<IdArray, IdArray> Graph::InEdges(IdArray vids) const {
Graph::EdgeArray Graph::InEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
......@@ -197,30 +206,42 @@ std::pair<IdArray, IdArray> Graph::InEdges(IdArray vids) const {
}
IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
IdArray eid = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) {
const auto& pred = adjlist_[vid_data[i]].pred;
const auto& eids = adjlist_[vid_data[i]].pred_edge_id;
for (size_t j = 0; j < pred.size(); ++j) {
*(src_ptr++) = pred[j];
*(dst_ptr++) = vid_data[i];
*(eid_ptr++) = eids[j];
}
}
return std::make_pair(src, dst);
return EdgeArray{src, dst, eid};
}
// O(E)
std::pair<IdArray, IdArray> Graph::OutEdges(dgl_id_t vid) const {
const auto& dst = Successors(vid);
const auto dstlen = dst->shape[0];
IdArray src = IdArray::Empty({dstlen}, dst->dtype, dst->ctx);
Graph::EdgeArray Graph::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* src_data = static_cast<int64_t*>(src->data);
std::fill(src_data, src_data + dstlen, vid);
return std::make_pair(src, dst);
int64_t* dst_data = static_cast<int64_t*>(dst->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) {
dst_data[i] = adjlist_[vid].succ[i];
eid_data[i] = adjlist_[vid].succ_edge_id[i];
}
std::fill(src_data, src_data + len, vid);
return EdgeArray{src, dst, eid};
}
// O(E)
std::pair<IdArray, IdArray> Graph::OutEdges(IdArray vids) const {
Graph::EdgeArray Graph::OutEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
......@@ -231,23 +252,28 @@ std::pair<IdArray, IdArray> Graph::OutEdges(IdArray vids) const {
}
IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
IdArray eid = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) {
const auto& succ = adjlist_[vid_data[i]].succ;
const auto& eids = adjlist_[vid_data[i]].succ_edge_id;
for (size_t j = 0; j < succ.size(); ++j) {
*(src_ptr++) = vid_data[i];
*(dst_ptr++) = succ[j];
*(eid_ptr++) = eids[j];
}
}
return std::make_pair(src, dst);
return EdgeArray{src, dst, eid};
}
// O(E*log(E)) if sort is required; otherwise, O(E)
std::pair<IdArray, IdArray> Graph::Edges(bool sorted) const {
Graph::EdgeArray Graph::Edges(bool sorted) const {
const int64_t len = num_edges_;
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
if (sorted) {
typedef std::tuple<int64_t, int64_t, int64_t> Tuple;
......@@ -255,7 +281,7 @@ std::pair<IdArray, IdArray> Graph::Edges(bool sorted) const {
tuples.reserve(len);
for (dgl_id_t u = 0; u < NumVertices(); ++u) {
for (size_t i = 0; i < adjlist_[u].succ.size(); ++i) {
tuples.push_back(std::make_tuple(u, adjlist_[u].succ[i], adjlist_[u].edge_id[i]));
tuples.emplace_back(u, adjlist_[u].succ[i], adjlist_[u].succ_edge_id[i]);
}
}
// sort according to edge ids
......@@ -267,22 +293,26 @@ std::pair<IdArray, IdArray> Graph::Edges(bool sorted) const {
// make return arrays
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) {
src_ptr[i] = std::get<0>(tuples[i]);
dst_ptr[i] = std::get<1>(tuples[i]);
eid_ptr[i] = std::get<2>(tuples[i]);
}
} else {
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
for (dgl_id_t u = 0; u < NumVertices(); ++u) {
for (size_t i = 0; i < adjlist_[u].succ.size(); ++i) {
*(src_ptr++) = u;
*(dst_ptr++) = adjlist_[u].succ[i];
*(eid_ptr++) = adjlist_[u].succ_edge_id[i];
}
}
}
return std::make_pair(src, dst);
return EdgeArray{src, dst, eid};
}
// O(V)
......
......@@ -10,11 +10,18 @@ using tvm::runtime::PackedFunc;
namespace dgl {
namespace {
template<typename T1, typename T2>
PackedFunc ConvertPairToPackedFunc(const std::pair<T1, T2>& pair) {
auto body = [pair] (TVMArgs args, TVMRetValue* rv) {
PackedFunc ConvertEdgeArrayToPackedFunc(const Graph::EdgeArray& ea) {
auto body = [ea] (TVMArgs args, TVMRetValue* rv) {
int which = args[0];
*rv = which? pair.second : pair.first;
if (which == 0) {
*rv = ea.src;
} else if (which == 1) {
*rv = ea.dst;
} else if (which == 2) {
*rv = ea.id;
} else {
LOG(FATAL) << "invalid choice";
}
};
return PackedFunc(body);
}
......@@ -159,7 +166,7 @@ TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphInEdges_1")
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = ConvertPairToPackedFunc(gptr->InEdges(vid));
*rv = ConvertEdgeArrayToPackedFunc(gptr->InEdges(vid));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphInEdges_2")
......@@ -167,7 +174,7 @@ TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphInEdges_2")
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = args[1];
*rv = ConvertPairToPackedFunc(gptr->InEdges(vids));
*rv = ConvertEdgeArrayToPackedFunc(gptr->InEdges(vids));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphOutEdges_1")
......@@ -175,7 +182,7 @@ TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphOutEdges_1")
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = ConvertPairToPackedFunc(gptr->OutEdges(vid));
*rv = ConvertEdgeArrayToPackedFunc(gptr->OutEdges(vid));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphOutEdges_2")
......@@ -183,7 +190,7 @@ TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphOutEdges_2")
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = args[1];
*rv = ConvertPairToPackedFunc(gptr->OutEdges(vids));
*rv = ConvertEdgeArrayToPackedFunc(gptr->OutEdges(vids));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphEdges")
......@@ -191,7 +198,7 @@ TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphEdges")
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const bool sorted = args[1];
*rv = ConvertPairToPackedFunc(gptr->Edges(sorted));
*rv = ConvertEdgeArrayToPackedFunc(gptr->Edges(sorted));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphInDegree")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment