Commit 71e6e0cd authored by Minjie Wang's avatar Minjie Wang
Browse files

python shim + some simple tests

parent 758cb16e
......@@ -115,16 +115,18 @@ class Graph {
/*!
* \brief Find the predecessors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \return the predecessor id array.
*/
IdArray Predecessors(dgl_id_t vid) const;
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const;
/*!
* \brief Find the successors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \return the successor id array.
*/
IdArray Successors(dgl_id_t vid) const;
IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const;
/*!
* \brief Get the id of the given edge.
......@@ -176,9 +178,10 @@ class Graph {
/*!
* \brief Get all the edges in the graph.
* \param sorted Whether the returned edge list is sorted by their edge ids.
* \return the id arrays of the two endpoints of the edges.
*/
std::pair<IdArray, IdArray> Edges() const;
std::pair<IdArray, IdArray> Edges(bool sorted = false) const;
/*!
* \brief Get the in degree of the given vertex.
......
......@@ -22,10 +22,10 @@ class DGLGraph(object):
v = utils.Index(v)
u_array = F.asdglarray(u.totensor())
v_array = F.asdglarray(v.totensor())
_CAPI_DGLGraphAddEdges(
self._handle,
u_array,
v_array)
_CAPI_DGLGraphAddEdges(self._handle, u_array, v_array)
def clear(self):
_CAPI_DGLGraphClear(self._handle)
def number_of_nodes(self):
return _CAPI_DGLGraphNumVertices(self._handle)
......@@ -33,4 +33,76 @@ class DGLGraph(object):
def number_of_edges(self):
return _CAPI_DGLGraphNumEdges(self._handle)
def has_vertex(self, vid):
return _CAPI_DGLGraphHasVertex(self._handle, vid)
def has_vertices(self, vids):
vids = utils.Index(vids)
vid_array = F.asdglarray(vids.totensor())
return _CAPI_DGLGraphHasVertices(self._handle, vid_array)
def has_edge(self, u, v):
return _CAPI_DGLGraphHasEdge(self._handle, u, v)
def has_edges(self, u, v):
u = utils.Index(u)
v = utils.Index(v)
u_array = F.asdglarray(u.totensor())
v_array = F.asdglarray(v.totensor())
return _CAPI_DGLGraphHasEdges(self._handle, u_array, v_array)
def predecessors(self, v, radius=1):
return _CAPI_DGLGraphPredecessors(self._handle, v, radius)
def successors(self, v, radius=1):
return _CAPI_DGLGraphSuccessors(self._handle, v, radius)
def edge_id(self, u, v):
return _CAPI_DGLGraphEdgeId(self._handle, u, v)
def edge_ids(self, u, v):
u = utils.Index(u)
v = utils.Index(v)
u_array = F.asdglarray(u.totensor())
v_array = F.asdglarray(v.totensor())
return _CAPI_DGLGraphEdgeIds(self._handle, u_array, v_array)
def inedges(self, v):
if isinstance(v, int):
pair = _CAPI_DGLGraphInEdges_1(self._handle, v)
else:
v = utils.Index(v)
v_array = F.asdglarray(v.totensor())
pair = _CAPI_DGLGraphInEdges_2(self._handle, v_array)
return pair(0), pair(1)
def outedges(self, v):
if isinstance(v, int):
pair = _CAPI_DGLGraphOutEdges_1(self._handle, v)
else:
v = utils.Index(v)
v_array = F.asdglarray(v.totensor())
pair = _CAPI_DGLGraphOutEdges_2(self._handle, v_array)
return pair(0), pair(1)
def indegree(self, v):
return _CAPI_DGLGraphInDegree(self._handle, v)
def indegrees(self, v):
v = utils.Index(v)
v_array = F.asdglarray(v.totensor())
return _CAPI_DGLGraphInDegrees(self._handle, v_array)
def outdegree(self, v):
return _CAPI_DGLGraphOutDegree(self._handle, v)
def outdegrees(self, v):
v = utils.Index(v)
v_array = F.asdglarray(v.totensor())
return _CAPI_DGLGraphOutDegrees(self._handle, v_array)
def edges(self, sorted=False):
pair = _CAPI_DGLGraphEdges(self._handle, sorted)
return pair(0), pair(1)
_init_api("dgl.cgraph")
......@@ -104,8 +104,9 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
}
// The data is copy-out; support zero-copy?
IdArray Graph::Predecessors(dgl_id_t vid) const {
IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
CHECK(radius >= 1) << "invalid radius: " << radius;
const auto& pred = adjlist_[vid].pred;
const int64_t len = pred.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
......@@ -117,8 +118,9 @@ IdArray Graph::Predecessors(dgl_id_t vid) const {
}
// The data is copy-out; support zero-copy?
IdArray Graph::Successors(dgl_id_t vid) const {
IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
CHECK(radius >= 1) << "invalid radius: " << radius;
const auto& succ = adjlist_[vid].succ;
const int64_t len = succ.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
......@@ -241,31 +243,43 @@ std::pair<IdArray, IdArray> Graph::OutEdges(IdArray vids) const {
return std::make_pair(src, dst);
}
// O(E*log(E)) due to sorting
std::pair<IdArray, IdArray> Graph::Edges() const {
// O(E*log(E)) if sort is required; otherwise, O(E)
std::pair<IdArray, IdArray> Graph::Edges(bool sorted) const {
const int64_t len = num_edges_;
typedef std::tuple<int64_t, int64_t, int64_t> Tuple;
std::vector<Tuple> tuples;
tuples.reserve(len);
for (dgl_id_t u = 0; u < NumVertices(); ++u) {
for (size_t i = 0; i < adjlist_[u].succ.size(); ++i) {
tuples.push_back(std::make_tuple(u, adjlist_[u].succ[i], adjlist_[u].edge_id[i]));
}
}
// sort according to edge ids
std::sort(tuples.begin(), tuples.end(),
[] (const Tuple& t1, const Tuple& t2) {
return std::get<2>(t1) < std::get<2>(t2);
});
// make return arrays
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
for (int64_t i = 0; i < len; ++i) {
src_ptr[i] = std::get<0>(tuples[i]);
dst_ptr[i] = std::get<1>(tuples[i]);
if (sorted) {
typedef std::tuple<int64_t, int64_t, int64_t> Tuple;
std::vector<Tuple> tuples;
tuples.reserve(len);
for (dgl_id_t u = 0; u < NumVertices(); ++u) {
for (size_t i = 0; i < adjlist_[u].succ.size(); ++i) {
tuples.push_back(std::make_tuple(u, adjlist_[u].succ[i], adjlist_[u].edge_id[i]));
}
}
// sort according to edge ids
std::sort(tuples.begin(), tuples.end(),
[] (const Tuple& t1, const Tuple& t2) {
return std::get<2>(t1) < std::get<2>(t2);
});
// make return arrays
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
for (int64_t i = 0; i < len; ++i) {
src_ptr[i] = std::get<0>(tuples[i]);
dst_ptr[i] = std::get<1>(tuples[i]);
}
} else {
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
for (dgl_id_t u = 0; u < NumVertices(); ++u) {
for (size_t i = 0; i < adjlist_[u].succ.size(); ++i) {
*(src_ptr++) = u;
*(dst_ptr++) = adjlist_[u].succ[i];
}
}
}
return std::make_pair(src, dst);
......
......@@ -8,74 +8,222 @@ using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc;
namespace dgl {
namespace {
template<typename T1, typename T2>
PackedFunc ConvertPairToPackedFunc(const std::pair<T1, T2>& pair) {
auto body = [pair] (TVMArgs args, TVMRetValue* rv) {
int which = args[0];
*rv = which? pair.second : pair.first;
};
return PackedFunc(body);
}
typedef void* GraphHandle;
} // namespace
void DGLGraphCreate(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = new Graph();
*rv = ghandle;
}
// Graph handler type
typedef void* GraphHandle;
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphCreate")
.set_body(DGLGraphCreate);
void DGLGraphFree(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
delete gptr;
}
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = new Graph();
*rv = ghandle;
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphFree")
.set_body(DGLGraphFree);
void DGLGraphAddVertices(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
uint64_t num_vertices = args[1];
gptr->AddVertices(num_vertices);
}
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
delete gptr;
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddVertices")
.set_body(DGLGraphAddVertices);
void DGLGraphAddEdge(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
gptr->AddEdge(src, dst);
}
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
uint64_t num_vertices = args[1];
gptr->AddVertices(num_vertices);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddEdge")
.set_body(DGLGraphAddEdge);
void DGLGraphAddEdges(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = args[1];
const IdArray dst = args[2];
gptr->AddEdges(src, dst);
}
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
gptr->AddEdge(src, dst);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddEdges")
.set_body(DGLGraphAddEdges);
void DGLGraphNumVertices(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumVertices());
}
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = args[1];
const IdArray dst = args[2];
gptr->AddEdges(src, dst);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphClear")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
gptr->Clear();
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphNumVertices")
.set_body(DGLGraphNumVertices);
void DGLGraphNumEdges(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumEdges());
}
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumVertices());
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphNumEdges")
.set_body(DGLGraphNumEdges);
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumEdges());
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphHasVertex")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = gptr->HasVertex(vid);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphHasVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = args[1];
*rv = gptr->HasVertices(vids);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphHasEdge")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
*rv = gptr->HasEdge(src, dst);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphHasEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = args[1];
const IdArray dst = args[2];
*rv = gptr->HasEdges(src, dst);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphPredecessors")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
const uint64_t radius = args[2];
*rv = gptr->Predecessors(vid, radius);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphSuccessors")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
const uint64_t radius = args[2];
*rv = gptr->Successors(vid, radius);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphEdgeId")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
*rv = static_cast<int64_t>(gptr->EdgeId(src, dst));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphEdgeIds")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = args[1];
const IdArray dst = args[2];
*rv = gptr->EdgeIds(src, dst);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphInEdges_1")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = ConvertPairToPackedFunc(gptr->InEdges(vid));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphInEdges_2")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = args[1];
*rv = ConvertPairToPackedFunc(gptr->InEdges(vids));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphOutEdges_1")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = ConvertPairToPackedFunc(gptr->OutEdges(vid));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphOutEdges_2")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = args[1];
*rv = ConvertPairToPackedFunc(gptr->OutEdges(vids));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const bool sorted = args[1];
*rv = ConvertPairToPackedFunc(gptr->Edges(sorted));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphInDegree")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(gptr->InDegree(vid));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphInDegrees")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = args[1];
*rv = gptr->InDegrees(vids);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphOutDegree")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(gptr->OutDegree(vid));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphOutDegrees")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = args[1];
*rv = gptr->OutDegrees(vids);
});
} // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment