"src/array/vscode:/vscode.git/clone" did not exist on "2cf4bd0acf479d8d51347b6b524aebd3fdcc8d9f"
Commit 71e6e0cd authored by Minjie Wang's avatar Minjie Wang
Browse files

python shim + some simple tests

parent 758cb16e
...@@ -115,16 +115,18 @@ class Graph { ...@@ -115,16 +115,18 @@ class Graph {
/*! /*!
* \brief Find the predecessors of a vertex. * \brief Find the predecessors of a vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \return the predecessor id array. * \return the predecessor id array.
*/ */
IdArray Predecessors(dgl_id_t vid) const; IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const;
/*! /*!
* \brief Find the successors of a vertex. * \brief Find the successors of a vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \return the successor id array. * \return the successor id array.
*/ */
IdArray Successors(dgl_id_t vid) const; IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const;
/*! /*!
* \brief Get the id of the given edge. * \brief Get the id of the given edge.
...@@ -176,9 +178,10 @@ class Graph { ...@@ -176,9 +178,10 @@ class Graph {
/*! /*!
* \brief Get all the edges in the graph. * \brief Get all the edges in the graph.
* \param sorted Whether the returned edge list is sorted by their edge ids.
* \return the id arrays of the two endpoints of the edges. * \return the id arrays of the two endpoints of the edges.
*/ */
std::pair<IdArray, IdArray> Edges() const; std::pair<IdArray, IdArray> Edges(bool sorted = false) const;
/*! /*!
* \brief Get the in degree of the given vertex. * \brief Get the in degree of the given vertex.
......
...@@ -22,10 +22,10 @@ class DGLGraph(object): ...@@ -22,10 +22,10 @@ class DGLGraph(object):
v = utils.Index(v) v = utils.Index(v)
u_array = F.asdglarray(u.totensor()) u_array = F.asdglarray(u.totensor())
v_array = F.asdglarray(v.totensor()) v_array = F.asdglarray(v.totensor())
_CAPI_DGLGraphAddEdges( _CAPI_DGLGraphAddEdges(self._handle, u_array, v_array)
self._handle,
u_array, def clear(self):
v_array) _CAPI_DGLGraphClear(self._handle)
def number_of_nodes(self): def number_of_nodes(self):
return _CAPI_DGLGraphNumVertices(self._handle) return _CAPI_DGLGraphNumVertices(self._handle)
...@@ -33,4 +33,76 @@ class DGLGraph(object): ...@@ -33,4 +33,76 @@ class DGLGraph(object):
def number_of_edges(self): def number_of_edges(self):
return _CAPI_DGLGraphNumEdges(self._handle) return _CAPI_DGLGraphNumEdges(self._handle)
def has_vertex(self, vid):
return _CAPI_DGLGraphHasVertex(self._handle, vid)
def has_vertices(self, vids):
vids = utils.Index(vids)
vid_array = F.asdglarray(vids.totensor())
return _CAPI_DGLGraphHasVertices(self._handle, vid_array)
def has_edge(self, u, v):
return _CAPI_DGLGraphHasEdge(self._handle, u, v)
def has_edges(self, u, v):
u = utils.Index(u)
v = utils.Index(v)
u_array = F.asdglarray(u.totensor())
v_array = F.asdglarray(v.totensor())
return _CAPI_DGLGraphHasEdges(self._handle, u_array, v_array)
def predecessors(self, v, radius=1):
return _CAPI_DGLGraphPredecessors(self._handle, v, radius)
def successors(self, v, radius=1):
return _CAPI_DGLGraphSuccessors(self._handle, v, radius)
def edge_id(self, u, v):
return _CAPI_DGLGraphEdgeId(self._handle, u, v)
def edge_ids(self, u, v):
u = utils.Index(u)
v = utils.Index(v)
u_array = F.asdglarray(u.totensor())
v_array = F.asdglarray(v.totensor())
return _CAPI_DGLGraphEdgeIds(self._handle, u_array, v_array)
def inedges(self, v):
if isinstance(v, int):
pair = _CAPI_DGLGraphInEdges_1(self._handle, v)
else:
v = utils.Index(v)
v_array = F.asdglarray(v.totensor())
pair = _CAPI_DGLGraphInEdges_2(self._handle, v_array)
return pair(0), pair(1)
def outedges(self, v):
if isinstance(v, int):
pair = _CAPI_DGLGraphOutEdges_1(self._handle, v)
else:
v = utils.Index(v)
v_array = F.asdglarray(v.totensor())
pair = _CAPI_DGLGraphOutEdges_2(self._handle, v_array)
return pair(0), pair(1)
def indegree(self, v):
return _CAPI_DGLGraphInDegree(self._handle, v)
def indegrees(self, v):
v = utils.Index(v)
v_array = F.asdglarray(v.totensor())
return _CAPI_DGLGraphInDegrees(self._handle, v_array)
def outdegree(self, v):
return _CAPI_DGLGraphOutDegree(self._handle, v)
def outdegrees(self, v):
v = utils.Index(v)
v_array = F.asdglarray(v.totensor())
return _CAPI_DGLGraphOutDegrees(self._handle, v_array)
def edges(self, sorted=False):
pair = _CAPI_DGLGraphEdges(self._handle, sorted)
return pair(0), pair(1)
_init_api("dgl.cgraph") _init_api("dgl.cgraph")
...@@ -104,8 +104,9 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const { ...@@ -104,8 +104,9 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
} }
// The data is copy-out; support zero-copy? // The data is copy-out; support zero-copy?
IdArray Graph::Predecessors(dgl_id_t vid) const { IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
CHECK(radius >= 1) << "invalid radius: " << radius;
const auto& pred = adjlist_[vid].pred; const auto& pred = adjlist_[vid].pred;
const int64_t len = pred.size(); const int64_t len = pred.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
...@@ -117,8 +118,9 @@ IdArray Graph::Predecessors(dgl_id_t vid) const { ...@@ -117,8 +118,9 @@ IdArray Graph::Predecessors(dgl_id_t vid) const {
} }
// The data is copy-out; support zero-copy? // The data is copy-out; support zero-copy?
IdArray Graph::Successors(dgl_id_t vid) const { IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
CHECK(radius >= 1) << "invalid radius: " << radius;
const auto& succ = adjlist_[vid].succ; const auto& succ = adjlist_[vid].succ;
const int64_t len = succ.size(); const int64_t len = succ.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
...@@ -241,9 +243,13 @@ std::pair<IdArray, IdArray> Graph::OutEdges(IdArray vids) const { ...@@ -241,9 +243,13 @@ std::pair<IdArray, IdArray> Graph::OutEdges(IdArray vids) const {
return std::make_pair(src, dst); return std::make_pair(src, dst);
} }
// O(E*log(E)) due to sorting // O(E*log(E)) if sort is required; otherwise, O(E)
std::pair<IdArray, IdArray> Graph::Edges() const { std::pair<IdArray, IdArray> Graph::Edges(bool sorted) const {
const int64_t len = num_edges_; const int64_t len = num_edges_;
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
if (sorted) {
typedef std::tuple<int64_t, int64_t, int64_t> Tuple; typedef std::tuple<int64_t, int64_t, int64_t> Tuple;
std::vector<Tuple> tuples; std::vector<Tuple> tuples;
tuples.reserve(len); tuples.reserve(len);
...@@ -259,14 +265,22 @@ std::pair<IdArray, IdArray> Graph::Edges() const { ...@@ -259,14 +265,22 @@ std::pair<IdArray, IdArray> Graph::Edges() const {
}); });
// make return arrays // make return arrays
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* src_ptr = static_cast<int64_t*>(src->data); int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data); int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
src_ptr[i] = std::get<0>(tuples[i]); src_ptr[i] = std::get<0>(tuples[i]);
dst_ptr[i] = std::get<1>(tuples[i]); dst_ptr[i] = std::get<1>(tuples[i]);
} }
} else {
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
for (dgl_id_t u = 0; u < NumVertices(); ++u) {
for (size_t i = 0; i < adjlist_[u].succ.size(); ++i) {
*(src_ptr++) = u;
*(dst_ptr++) = adjlist_[u].succ[i];
}
}
}
return std::make_pair(src, dst); return std::make_pair(src, dst);
} }
......
...@@ -8,74 +8,222 @@ using tvm::runtime::TVMRetValue; ...@@ -8,74 +8,222 @@ using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc; using tvm::runtime::PackedFunc;
namespace dgl { namespace dgl {
namespace {
template<typename T1, typename T2>
PackedFunc ConvertPairToPackedFunc(const std::pair<T1, T2>& pair) {
auto body = [pair] (TVMArgs args, TVMRetValue* rv) {
int which = args[0];
*rv = which? pair.second : pair.first;
};
return PackedFunc(body);
}
} // namespace
// Graph handler type
typedef void* GraphHandle; typedef void* GraphHandle;
void DGLGraphCreate(TVMArgs args, TVMRetValue* rv) { TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphCreate")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = new Graph(); GraphHandle ghandle = new Graph();
*rv = ghandle; *rv = ghandle;
} });
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphCreate")
.set_body(DGLGraphCreate);
void DGLGraphFree(TVMArgs args, TVMRetValue* rv) { TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphFree")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle); Graph* gptr = static_cast<Graph*>(ghandle);
delete gptr; delete gptr;
} });
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphFree") TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddVertices")
.set_body(DGLGraphFree); .set_body([] (TVMArgs args, TVMRetValue* rv) {
void DGLGraphAddVertices(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle); Graph* gptr = static_cast<Graph*>(ghandle);
uint64_t num_vertices = args[1]; uint64_t num_vertices = args[1];
gptr->AddVertices(num_vertices); gptr->AddVertices(num_vertices);
} });
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddVertices")
.set_body(DGLGraphAddVertices);
void DGLGraphAddEdge(TVMArgs args, TVMRetValue* rv) { TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddEdge")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle); Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1]; const dgl_id_t src = args[1];
const dgl_id_t dst = args[2]; const dgl_id_t dst = args[2];
gptr->AddEdge(src, dst); gptr->AddEdge(src, dst);
} });
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddEdge") TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddEdges")
.set_body(DGLGraphAddEdge); .set_body([] (TVMArgs args, TVMRetValue* rv) {
void DGLGraphAddEdges(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle); Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = args[1]; const IdArray src = args[1];
const IdArray dst = args[2]; const IdArray dst = args[2];
gptr->AddEdges(src, dst); gptr->AddEdges(src, dst);
} });
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddEdges") TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphClear")
.set_body(DGLGraphAddEdges); .set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
gptr->Clear();
});
void DGLGraphNumVertices(TVMArgs args, TVMRetValue* rv) { TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphNumVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumVertices()); *rv = static_cast<int64_t>(gptr->NumVertices());
} });
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphNumVertices") TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphNumEdges")
.set_body(DGLGraphNumVertices); .set_body([] (TVMArgs args, TVMRetValue* rv) {
void DGLGraphNumEdges(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumEdges()); *rv = static_cast<int64_t>(gptr->NumEdges());
} });
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphNumEdges") TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphHasVertex")
.set_body(DGLGraphNumEdges); .set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = gptr->HasVertex(vid);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphHasVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = args[1];
*rv = gptr->HasVertices(vids);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphHasEdge")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
*rv = gptr->HasEdge(src, dst);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphHasEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = args[1];
const IdArray dst = args[2];
*rv = gptr->HasEdges(src, dst);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphPredecessors")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
const uint64_t radius = args[2];
*rv = gptr->Predecessors(vid, radius);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphSuccessors")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
const uint64_t radius = args[2];
*rv = gptr->Successors(vid, radius);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphEdgeId")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
*rv = static_cast<int64_t>(gptr->EdgeId(src, dst));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphEdgeIds")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = args[1];
const IdArray dst = args[2];
*rv = gptr->EdgeIds(src, dst);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphInEdges_1")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = ConvertPairToPackedFunc(gptr->InEdges(vid));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphInEdges_2")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = args[1];
*rv = ConvertPairToPackedFunc(gptr->InEdges(vids));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphOutEdges_1")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = ConvertPairToPackedFunc(gptr->OutEdges(vid));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphOutEdges_2")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = args[1];
*rv = ConvertPairToPackedFunc(gptr->OutEdges(vids));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const bool sorted = args[1];
*rv = ConvertPairToPackedFunc(gptr->Edges(sorted));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphInDegree")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(gptr->InDegree(vid));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphInDegrees")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = args[1];
*rv = gptr->InDegrees(vids);
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphOutDegree")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(gptr->OutDegree(vid));
});
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphOutDegrees")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = args[1];
*rv = gptr->OutDegrees(vids);
});
} // namespace dgl } // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment