Unverified Commit 67dc1197 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Refactor] Use object system for all CAPIs (#716)

* WIP: using object system for graph

* c++ side refactoring done; compiled

* remove stale apis

* fix bug in DGLGraphCreate; passed test_graph.py

* fix bug in python modify; passed utest for pytorch/cpu

* fix lint

* address comments
parent b0d9e7aa
......@@ -25,25 +25,4 @@ PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) {
return PackedFunc(body);
}
DGL_REGISTER_GLOBAL("_GetVectorWrapperSize")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
const CAPIVectorWrapper* wrapper = static_cast<const CAPIVectorWrapper*>(ptr);
*rv = static_cast<int64_t>(wrapper->pointers.size());
});
DGL_REGISTER_GLOBAL("_GetVectorWrapperData")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
CAPIVectorWrapper* wrapper = static_cast<CAPIVectorWrapper*>(ptr);
*rv = static_cast<void*>(wrapper->pointers.data());
});
DGL_REGISTER_GLOBAL("_FreeVectorWrapper")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
CAPIVectorWrapper* wrapper = static_cast<CAPIVectorWrapper*>(ptr);
delete wrapper;
});
} // namespace dgl
......@@ -34,9 +34,6 @@ inline std::ostream& operator << (std::ostream& os, const DLContext& ctx) {
namespace dgl {
// Graph handler type
typedef void* GraphHandle;
// Communicator handler type
typedef void* CommunicatorHandle;
......@@ -73,29 +70,6 @@ dgl::runtime::NDArray CopyVectorToNDArray(
return a;
}
/* A structure used to return a vector of void* pointers. */
struct CAPIVectorWrapper {
// The pointer vector.
std::vector<void*> pointers;
};
/*!
* \brief A helper function used to return vector of pointers from C to frontend.
*
* Note that the function will move the given vector memory into the returned
* wrapper object.
*
* \param vec The given pointer vectors.
* \return A wrapper object containing the given data.
*/
template<typename PType>
CAPIVectorWrapper* WrapVectorReturn(std::vector<PType*> vec) {
CAPIVectorWrapper* wrapper = new CAPIVectorWrapper;
wrapper->pointers.reserve(vec.size());
wrapper->pointers.insert(wrapper->pointers.end(), vec.begin(), vec.end());
return wrapper;
}
} // namespace dgl
#endif // DGL_C_API_COMMON_H_
......@@ -200,7 +200,7 @@ IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
}
// O(E*k) pretty slow
Graph::EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
......@@ -246,7 +246,7 @@ Graph::EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
return EdgeArray{rst_src, rst_dst, rst_eid};
}
Graph::EdgeArray Graph::FindEdges(IdArray eids) const {
EdgeArray Graph::FindEdges(IdArray eids) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
int64_t len = eids->shape[0];
......@@ -272,7 +272,7 @@ Graph::EdgeArray Graph::FindEdges(IdArray eids) const {
}
// O(E)
Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const {
EdgeArray Graph::InEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = reverse_adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
......@@ -290,7 +290,7 @@ Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const {
}
// O(E)
Graph::EdgeArray Graph::InEdges(IdArray vids) const {
EdgeArray Graph::InEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
......@@ -318,7 +318,7 @@ Graph::EdgeArray Graph::InEdges(IdArray vids) const {
}
// O(E)
Graph::EdgeArray Graph::OutEdges(dgl_id_t vid) const {
EdgeArray Graph::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
......@@ -336,7 +336,7 @@ Graph::EdgeArray Graph::OutEdges(dgl_id_t vid) const {
}
// O(E)
Graph::EdgeArray Graph::OutEdges(IdArray vids) const {
EdgeArray Graph::OutEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
......@@ -364,7 +364,7 @@ Graph::EdgeArray Graph::OutEdges(IdArray vids) const {
}
// O(E*log(E)) if sort is required; otherwise, O(E)
Graph::EdgeArray Graph::Edges(const std::string &order) const {
EdgeArray Graph::Edges(const std::string &order) const {
const int64_t len = num_edges_;
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
......@@ -585,9 +585,4 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const
}
}
GraphPtr Graph::Reverse() const {
LOG(FATAL) << "not implemented";
return nullptr;
}
} // namespace dgl
......@@ -3,6 +3,7 @@
* \file graph/graph.cc
* \brief DGL graph index APIs
*/
#include <dgl/packed_func_ext.h>
#include <dgl/graph.h>
#include <dgl/immutable_graph.h>
#include <dgl/graph_op.h>
......@@ -55,9 +56,7 @@ PackedFunc ConvertSubgraphToPackedFunc(const Subgraph& sg) {
auto body = [sg] (DGLArgs args, DGLRetValue* rv) {
const int which = args[0];
if (which == 0) {
GraphInterface* gptr = sg.graph->Reset();
GraphHandle ghandle = gptr;
*rv = ghandle;
*rv = GraphRef(sg.graph);
} else if (which == 1) {
*rv = std::move(sg.induced_vertices);
} else if (which == 2) {
......@@ -71,60 +70,12 @@ PackedFunc ConvertSubgraphToPackedFunc(const Subgraph& sg) {
} // namespace
namespace {
// This namespace contains template functions for batching
// and unbatching over graph and immutable graph
template<typename T>
void DGLDisjointPartitionByNum(const T *gptr, DGLArgs args, DGLRetValue *rv) {
int64_t num = args[1];
std::vector<T> &&rst = GraphOp::DisjointPartitionByNum(gptr, num);
// return the pointer array as an integer array
const int64_t len = rst.size();
NDArray ptr_array = NDArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t *ptr_array_data = static_cast<int64_t *>(ptr_array->data);
for (size_t i = 0; i < rst.size(); ++i) {
GraphInterface *ptr = rst[i].Reset();
ptr_array_data[i] = reinterpret_cast<std::intptr_t>(ptr);
}
*rv = ptr_array;
}
template<typename T>
void DGLDisjointUnion(GraphHandle *inhandles, int list_size, DGLRetValue *rv) {
std::vector<const T *> graphs;
for (int i = 0; i < list_size; ++i) {
const GraphInterface *ptr = static_cast<const GraphInterface *>(inhandles[i]);
const T *gr = dynamic_cast<const T *>(ptr);
CHECK(gr) << "Error: Attempted to batch MutableGraph with ImmutableGraph";
graphs.push_back(gr);
}
GraphHandle ghandle = GraphOp::DisjointUnion(std::move(graphs)).Reset();
*rv = ghandle;
}
template<typename T>
void DGLDisjointPartitionBySizes(const T *gptr, const IdArray sizes, DGLRetValue *rv) {
std::vector<T> &&rst = GraphOp::DisjointPartitionBySizes(gptr, sizes);
// return the pointer array as an integer array
const int64_t len = rst.size();
NDArray ptr_array = NDArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t *ptr_array_data = static_cast<int64_t *>(ptr_array->data);
for (size_t i = 0; i < rst.size(); ++i) {
GraphInterface *ptr = rst[i].Reset();
ptr_array_data[i] = reinterpret_cast<std::intptr_t>(ptr);
}
*rv = ptr_array;
}
} // namespace
///////////////////////////// Graph API ///////////////////////////////////
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreateMutable")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
bool multigraph = static_cast<bool>(args[0]);
GraphHandle ghandle = new Graph(multigraph);
*rv = ghandle;
bool multigraph = args[0];
*rv = GraphRef(Graph::Create(multigraph));
});
......@@ -135,20 +86,16 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate")
const int multigraph = args[2];
const int64_t num_nodes = args[3];
const bool readonly = args[4];
GraphHandle ghandle;
if (readonly) {
if (multigraph == kBoolUnknown) {
COOPtr coo(new COO(num_nodes, src_ids, dst_ids));
ghandle = new ImmutableGraph(coo);
*rv = GraphRef(ImmutableGraph::CreateFromCOO(num_nodes, src_ids, dst_ids));
} else {
COOPtr coo(new COO(num_nodes, src_ids, dst_ids, multigraph));
ghandle = new ImmutableGraph(coo);
*rv = GraphRef(ImmutableGraph::CreateFromCOO(num_nodes, src_ids, dst_ids, multigraph));
}
} else {
CHECK_NE(multigraph, kBoolUnknown);
ghandle = new Graph(src_ids, dst_ids, num_nodes, multigraph);
*rv = GraphRef(Graph::CreateFromCOO(num_nodes, src_ids, dst_ids, multigraph));
}
*rv = ghandle;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
......@@ -164,26 +111,22 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
int64_t *edge_data = static_cast<int64_t *>(edge_ids->data);
for (size_t i = 0; i < edge_ids->shape[0]; i++)
edge_data[i] = i;
ImmutableGraph *g = nullptr;
if (shared_mem_name.empty()) {
if (multigraph == kBoolUnknown) {
g = new ImmutableGraph(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids,
edge_dir));
*rv = GraphRef(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, edge_dir));
} else {
g = new ImmutableGraph(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids,
multigraph, edge_dir));
*rv = GraphRef(ImmutableGraph::CreateFromCSR(
indptr, indices, edge_ids, multigraph, edge_dir));
}
} else {
if (multigraph == kBoolUnknown) {
g = new ImmutableGraph(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids,
edge_dir, shared_mem_name));
*rv = GraphRef(ImmutableGraph::CreateFromCSR(
indptr, indices, edge_ids, edge_dir, shared_mem_name));
} else {
g = new ImmutableGraph(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids,
multigraph, edge_dir,
shared_mem_name));
*rv = GraphRef(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids,
multigraph, edge_dir, shared_mem_name));
}
}
*rv = g;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap")
......@@ -194,407 +137,229 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap")
const bool multigraph = args[3];
const std::string edge_dir = args[4];
// TODO(minjie): how to know multigraph
GraphHandle ghandle = new ImmutableGraph(ImmutableGraph::CreateFromCSR(
*rv = GraphRef(ImmutableGraph::CreateFromCSR(
shared_mem_name, num_vertices, num_edges, multigraph, edge_dir));
*rv = ghandle;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
delete gptr;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
uint64_t num_vertices = args[1];
gptr->AddVertices(num_vertices);
g->AddVertices(num_vertices);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
gptr->AddEdge(src, dst);
g->AddEdge(src, dst);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const IdArray src = args[1];
const IdArray dst = args[2];
gptr->AddEdges(src, dst);
g->AddEdges(src, dst);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphClear")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
gptr->Clear();
GraphRef g = args[0];
g->Clear();
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsMultigraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
GraphHandle ghandle = args[0];
// NOTE: not const since we have caches
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
*rv = gptr->IsMultigraph();
GraphRef g = args[0];
*rv = g->IsMultigraph();
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsReadonly")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
GraphHandle ghandle = args[0];
// NOTE: not const since we have caches
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
*rv = gptr->IsReadonly();
GraphRef g = args[0];
*rv = g->IsReadonly();
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumVertices());
GraphRef g = args[0];
*rv = static_cast<int64_t>(g->NumVertices());
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumEdges());
GraphRef g = args[0];
*rv = static_cast<int64_t>(g->NumEdges());
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertex")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const dgl_id_t vid = args[1];
*rv = gptr->HasVertex(vid);
*rv = g->HasVertex(vid);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const IdArray vids = args[1];
*rv = gptr->HasVertices(vids);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLMapSubgraphNID")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const IdArray parent_vids = args[0];
const IdArray query = args[1];
*rv = GraphOp::MapParentIdToSubgraphId(parent_vids, query);
*rv = g->HasVertices(vids);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
*rv = gptr->HasEdgeBetween(src, dst);
*rv = g->HasEdgeBetween(src, dst);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const IdArray src = args[1];
const IdArray dst = args[2];
*rv = gptr->HasEdgesBetween(src, dst);
*rv = g->HasEdgesBetween(src, dst);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const dgl_id_t vid = args[1];
const uint64_t radius = args[2];
*rv = gptr->Predecessors(vid, radius);
*rv = g->Predecessors(vid, radius);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const dgl_id_t vid = args[1];
const uint64_t radius = args[2];
*rv = gptr->Successors(vid, radius);
*rv = g->Successors(vid, radius);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
*rv = gptr->EdgeId(src, dst);
*rv = g->EdgeId(src, dst);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const IdArray src = args[1];
const IdArray dst = args[2];
*rv = ConvertEdgeArrayToPackedFunc(gptr->EdgeIds(src, dst));
*rv = ConvertEdgeArrayToPackedFunc(g->EdgeIds(src, dst));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const IdArray eids = args[1];
*rv = ConvertEdgeArrayToPackedFunc(gptr->FindEdges(eids));
*rv = ConvertEdgeArrayToPackedFunc(g->FindEdges(eids));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const dgl_id_t vid = args[1];
*rv = ConvertEdgeArrayToPackedFunc(gptr->InEdges(vid));
*rv = ConvertEdgeArrayToPackedFunc(g->InEdges(vid));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const IdArray vids = args[1];
*rv = ConvertEdgeArrayToPackedFunc(gptr->InEdges(vids));
*rv = ConvertEdgeArrayToPackedFunc(g->InEdges(vids));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const dgl_id_t vid = args[1];
*rv = ConvertEdgeArrayToPackedFunc(gptr->OutEdges(vid));
*rv = ConvertEdgeArrayToPackedFunc(g->OutEdges(vid));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const IdArray vids = args[1];
*rv = ConvertEdgeArrayToPackedFunc(gptr->OutEdges(vids));
*rv = ConvertEdgeArrayToPackedFunc(g->OutEdges(vids));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
std::string order = args[1];
*rv = ConvertEdgeArrayToPackedFunc(gptr->Edges(order));
*rv = ConvertEdgeArrayToPackedFunc(g->Edges(order));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(gptr->InDegree(vid));
*rv = static_cast<int64_t>(g->InDegree(vid));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const IdArray vids = args[1];
*rv = gptr->InDegrees(vids);
*rv = g->InDegrees(vids);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(gptr->OutDegree(vid));
*rv = static_cast<int64_t>(g->OutDegree(vid));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const IdArray vids = args[1];
*rv = gptr->OutDegrees(vids);
*rv = g->OutDegrees(vids);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface* gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const IdArray vids = args[1];
*rv = ConvertSubgraphToPackedFunc(gptr->VertexSubgraph(vids));
*rv = ConvertSubgraphToPackedFunc(g->VertexSubgraph(vids));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface *gptr = static_cast<GraphInterface*>(ghandle);
GraphRef g = args[0];
const IdArray eids = args[1];
bool preserve_nodes = args[2];
*rv = ConvertSubgraphToPackedFunc(gptr->EdgeSubgraph(eids, preserve_nodes));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* list = args[0];
GraphHandle* inhandles = static_cast<GraphHandle*>(list);
int list_size = args[1];
const GraphInterface *ptr = static_cast<const GraphInterface *>(inhandles[0]);
const ImmutableGraph *im_gr = dynamic_cast<const ImmutableGraph *>(ptr);
const Graph *gr = dynamic_cast<const Graph *>(ptr);
if (gr) {
DGLDisjointUnion<Graph>(inhandles, list_size, rv);
} else {
CHECK(im_gr) << "Args[0] is not a list of valid DGLGraph";
DGLDisjointUnion<ImmutableGraph>(inhandles, list_size, rv);
}
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
const Graph* gptr = dynamic_cast<const Graph*>(ptr);
const ImmutableGraph* im_gptr = dynamic_cast<const ImmutableGraph*>(ptr);
if (gptr) {
DGLDisjointPartitionByNum(gptr, args, rv);
} else {
CHECK(im_gptr) << "Args[0] is not a valid DGLGraph";
DGLDisjointPartitionByNum(im_gptr, args, rv);
}
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const IdArray sizes = args[1];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
const Graph* gptr = dynamic_cast<const Graph*>(ptr);
const ImmutableGraph* im_gptr = dynamic_cast<const ImmutableGraph*>(ptr);
if (gptr) {
DGLDisjointPartitionBySizes(gptr, sizes, rv);
} else {
CHECK(im_gptr) << "Args[0] is not a valid DGLGraph";
DGLDisjointPartitionBySizes(im_gptr, sizes, rv);
}
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLineGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
bool backtracking = args[1];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
const Graph* gptr = dynamic_cast<const Graph*>(ptr);
CHECK(gptr) << "_CAPI_DGLGraphLineGraph isn't implemented in immutable graph";
Graph* lgptr = new Graph();
*lgptr = GraphOp::LineGraph(gptr, backtracking);
GraphHandle lghandle = lgptr;
*rv = lghandle;
*rv = ConvertSubgraphToPackedFunc(g->EdgeSubgraph(eids, preserve_nodes));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphGetAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
GraphRef g = args[0];
bool transpose = args[1];
std::string format = args[2];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
auto res = ptr->GetAdj(transpose, format);
auto res = g->GetAdj(transpose, format);
*rv = ConvertAdjToPackedFunc(res);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLToImmutable")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface *ptr = static_cast<GraphInterface *>(ghandle);
GraphHandle newhandle = new ImmutableGraph(ImmutableGraph::ToImmutable(ptr));
*rv = newhandle;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphContext")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface *ptr = static_cast<GraphInterface *>(ghandle);
*rv = ptr->Context();
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const int device_type = args[1];
const int device_id = args[2];
DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
const GraphInterface *ptr = static_cast<GraphInterface *>(ghandle);
const ImmutableGraph *ig = dynamic_cast<const ImmutableGraph*>(ptr);
CHECK(ig) << "Invalid argument: must be an immutable graph object.";
GraphHandle newhandle = new ImmutableGraph(ig->CopyTo(ctx));
*rv = newhandle;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
std::string edge_dir = args[1];
std::string name = args[2];
const GraphInterface *ptr = static_cast<GraphInterface *>(ghandle);
const ImmutableGraph *ig = dynamic_cast<const ImmutableGraph*>(ptr);
CHECK(ig) << "Invalid argument: must be an immutable graph object.";
GraphHandle newhandle = new ImmutableGraph(ig->CopyToSharedMem(edge_dir, name));
*rv = newhandle;
GraphRef g = args[0];
*rv = g->Context();
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface *ptr = static_cast<GraphInterface *>(ghandle);
*rv = ptr->NumBits();
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
int bits = args[1];
const GraphInterface *ptr = static_cast<GraphInterface *>(ghandle);
const ImmutableGraph *ig = dynamic_cast<const ImmutableGraph*>(ptr);
CHECK(ig) << "Invalid argument: must be an immutable graph object.";
GraphHandle newhandle = new ImmutableGraph(ig->AsNumBits(bits));
*rv = newhandle;
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
GraphHandle ret = GraphOp::ToSimpleGraph(ptr).Reset();
*rv = ret;
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedMutableGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
Graph* bgptr = new Graph();
*bgptr = GraphOp::ToBidirectedMutableGraph(ptr);
GraphHandle bghandle = bgptr;
*rv = bghandle;
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedImmutableGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
GraphHandle bghandle = GraphOp::ToBidirectedImmutableGraph(ptr).Reset();
*rv = bghandle;
GraphRef g = args[0];
*rv = g->NumBits();
});
} // namespace dgl
......@@ -5,9 +5,13 @@
*/
#include <dgl/graph_op.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <algorithm>
#include "../c_api_common.h"
using namespace dgl::runtime;
namespace dgl {
namespace {
// generate consecutive dgl ids
......@@ -38,199 +42,208 @@ class RangeIter : public std::iterator<std::input_iterator_tag, dgl_id_t> {
private:
dgl_id_t cur_;
};
} // namespace
Graph GraphOp::LineGraph(const Graph* g, bool backtracking) {
Graph lg;
lg.AddVertices(g->NumEdges());
for (size_t i = 0; i < g->all_edges_src_.size(); ++i) {
const auto u = g->all_edges_src_[i];
const auto v = g->all_edges_dst_[i];
for (size_t j = 0; j < g->adjlist_[v].succ.size(); ++j) {
if (backtracking || (!backtracking && g->adjlist_[v].succ[j] != u)) {
lg.AddEdge(i, g->adjlist_[v].edge_id[j]);
}
}
}
return lg;
bool IsMutable(GraphPtr g) {
MutableGraphPtr mg = std::dynamic_pointer_cast<Graph>(g);
return mg != nullptr;
}
Graph GraphOp::DisjointUnion(std::vector<const Graph*> graphs) {
Graph rst;
uint64_t cumsum = 0;
for (const Graph* gr : graphs) {
rst.AddVertices(gr->NumVertices());
for (uint64_t i = 0; i < gr->NumEdges(); ++i) {
rst.AddEdge(gr->all_edges_src_[i] + cumsum, gr->all_edges_dst_[i] + cumsum);
}
cumsum += gr->NumVertices();
}
return rst;
}
} // namespace
std::vector<Graph> GraphOp::DisjointPartitionByNum(const Graph* graph, int64_t num) {
CHECK(num != 0 && graph->NumVertices() % num == 0)
<< "Number of partitions must evenly divide the number of nodes.";
IdArray sizes = IdArray::Empty({num}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num);
return DisjointPartitionBySizes(graph, sizes);
GraphPtr GraphOp::Reverse(GraphPtr g) {
ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(g);
CHECK(ig) << "Reverse is only supported on immutable graph";
return ig->Reverse();
}
std::vector<Graph> GraphOp::DisjointPartitionBySizes(const Graph* graph, IdArray sizes) {
const int64_t len = sizes->shape[0];
const int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
std::vector<int64_t> cumsum;
cumsum.push_back(0);
for (int64_t i = 0; i < len; ++i) {
cumsum.push_back(cumsum[i] + sizes_data[i]);
}
CHECK_EQ(cumsum[len], graph->NumVertices())
<< "Sum of the given sizes must equal to the number of nodes.";
dgl_id_t node_offset = 0, edge_offset = 0;
std::vector<Graph> rst(len);
for (int64_t i = 0; i < len; ++i) {
// copy adj
rst[i].adjlist_.insert(rst[i].adjlist_.end(),
graph->adjlist_.begin() + node_offset,
graph->adjlist_.begin() + node_offset + sizes_data[i]);
rst[i].reverse_adjlist_.insert(rst[i].reverse_adjlist_.end(),
graph->reverse_adjlist_.begin() + node_offset,
graph->reverse_adjlist_.begin() + node_offset + sizes_data[i]);
// relabel adjs
size_t num_edges = 0;
for (auto& elist : rst[i].adjlist_) {
for (size_t j = 0; j < elist.succ.size(); ++j) {
elist.succ[j] -= node_offset;
elist.edge_id[j] -= edge_offset;
}
num_edges += elist.succ.size();
}
for (auto& elist : rst[i].reverse_adjlist_) {
for (size_t j = 0; j < elist.succ.size(); ++j) {
elist.succ[j] -= node_offset;
elist.edge_id[j] -= edge_offset;
GraphPtr GraphOp::LineGraph(GraphPtr g, bool backtracking) {
MutableGraphPtr mg = std::dynamic_pointer_cast<Graph>(g);
CHECK(mg) << "Line graph transformation is only supported on mutable graph";
MutableGraphPtr lg = Graph::Create();
lg->AddVertices(g->NumEdges());
for (size_t i = 0; i < mg->all_edges_src_.size(); ++i) {
const auto u = mg->all_edges_src_[i];
const auto v = mg->all_edges_dst_[i];
for (size_t j = 0; j < mg->adjlist_[v].succ.size(); ++j) {
if (backtracking || (!backtracking && mg->adjlist_[v].succ[j] != u)) {
lg->AddEdge(i, mg->adjlist_[v].edge_id[j]);
}
}
// copy edges
rst[i].all_edges_src_.reserve(num_edges);
rst[i].all_edges_dst_.reserve(num_edges);
rst[i].num_edges_ = num_edges;
for (size_t j = edge_offset; j < edge_offset + num_edges; ++j) {
rst[i].all_edges_src_.push_back(graph->all_edges_src_[j] - node_offset);
rst[i].all_edges_dst_.push_back(graph->all_edges_dst_[j] - node_offset);
}
// update offset
CHECK_EQ(rst[i].NumVertices(), sizes_data[i]);
CHECK_EQ(rst[i].NumEdges(), num_edges);
node_offset += sizes_data[i];
edge_offset += num_edges;
}
return rst;
return lg;
}
ImmutableGraph GraphOp::DisjointUnion(std::vector<const ImmutableGraph *> graphs) {
int64_t num_nodes = 0;
int64_t num_edges = 0;
for (const ImmutableGraph *gr : graphs) {
num_nodes += gr->NumVertices();
num_edges += gr->NumEdges();
}
IdArray indptr_arr = aten::NewIdArray(num_nodes + 1);
IdArray indices_arr = aten::NewIdArray(num_edges);
IdArray edge_ids_arr = aten::NewIdArray(num_edges);
dgl_id_t* indptr = static_cast<dgl_id_t*>(indptr_arr->data);
dgl_id_t* indices = static_cast<dgl_id_t*>(indices_arr->data);
dgl_id_t* edge_ids = static_cast<dgl_id_t*>(edge_ids_arr->data);
indptr[0] = 0;
dgl_id_t cum_num_nodes = 0;
dgl_id_t cum_num_edges = 0;
for (const ImmutableGraph *gr : graphs) {
const CSRPtr g_csrptr = gr->GetInCSR();
const int64_t g_num_nodes = g_csrptr->NumVertices();
const int64_t g_num_edges = g_csrptr->NumEdges();
dgl_id_t* g_indptr = static_cast<dgl_id_t*>(g_csrptr->indptr()->data);
dgl_id_t* g_indices = static_cast<dgl_id_t*>(g_csrptr->indices()->data);
dgl_id_t* g_edge_ids = static_cast<dgl_id_t*>(g_csrptr->edge_ids()->data);
for (dgl_id_t i = 1; i < g_num_nodes + 1; ++i) {
indptr[cum_num_nodes + i] = g_indptr[i] + cum_num_edges;
GraphPtr GraphOp::DisjointUnion(std::vector<GraphPtr> graphs) {
CHECK_GT(graphs.size(), 0) << "Input graph list is empty";
if (IsMutable(graphs[0])) {
// Disjointly union of a list of mutable graph inputs. The result is
// also a mutable graph.
MutableGraphPtr rst = Graph::Create();
uint64_t cumsum = 0;
for (GraphPtr gr : graphs) {
MutableGraphPtr mg = std::dynamic_pointer_cast<Graph>(gr);
CHECK(mg) << "All the input graphs should be mutable graphs.";
rst->AddVertices(gr->NumVertices());
for (uint64_t i = 0; i < gr->NumEdges(); ++i) {
// TODO(minjie): quite ugly to expose internal members
rst->AddEdge(mg->all_edges_src_[i] + cumsum, mg->all_edges_dst_[i] + cumsum);
}
cumsum += gr->NumVertices();
}
for (dgl_id_t i = 0; i < g_num_edges; ++i) {
indices[cum_num_edges + i] = g_indices[i] + cum_num_nodes;
return rst;
} else {
// Disjointly union of a list of immutable graph inputs. The result is
// also an immutable graph.
int64_t num_nodes = 0;
int64_t num_edges = 0;
for (auto gr : graphs) {
num_nodes += gr->NumVertices();
num_edges += gr->NumEdges();
}
IdArray indptr_arr = aten::NewIdArray(num_nodes + 1);
IdArray indices_arr = aten::NewIdArray(num_edges);
IdArray edge_ids_arr = aten::NewIdArray(num_edges);
dgl_id_t* indptr = static_cast<dgl_id_t*>(indptr_arr->data);
dgl_id_t* indices = static_cast<dgl_id_t*>(indices_arr->data);
dgl_id_t* edge_ids = static_cast<dgl_id_t*>(edge_ids_arr->data);
for (dgl_id_t i = 0; i < g_num_edges; ++i) {
edge_ids[cum_num_edges + i] = g_edge_ids[i] + cum_num_edges;
indptr[0] = 0;
dgl_id_t cum_num_nodes = 0;
dgl_id_t cum_num_edges = 0;
for (auto g : graphs) {
ImmutableGraphPtr gr = std::dynamic_pointer_cast<ImmutableGraph>(g);
CHECK(gr) << "All the input graphs should be immutable graphs.";
// TODO(minjie): why in csr?
const CSRPtr g_csrptr = gr->GetInCSR();
const int64_t g_num_nodes = g_csrptr->NumVertices();
const int64_t g_num_edges = g_csrptr->NumEdges();
dgl_id_t* g_indptr = static_cast<dgl_id_t*>(g_csrptr->indptr()->data);
dgl_id_t* g_indices = static_cast<dgl_id_t*>(g_csrptr->indices()->data);
dgl_id_t* g_edge_ids = static_cast<dgl_id_t*>(g_csrptr->edge_ids()->data);
for (dgl_id_t i = 1; i < g_num_nodes + 1; ++i) {
indptr[cum_num_nodes + i] = g_indptr[i] + cum_num_edges;
}
for (dgl_id_t i = 0; i < g_num_edges; ++i) {
indices[cum_num_edges + i] = g_indices[i] + cum_num_nodes;
}
for (dgl_id_t i = 0; i < g_num_edges; ++i) {
edge_ids[cum_num_edges + i] = g_edge_ids[i] + cum_num_edges;
}
cum_num_nodes += g_num_nodes;
cum_num_edges += g_num_edges;
}
cum_num_nodes += g_num_nodes;
cum_num_edges += g_num_edges;
}
CSRPtr batched_csr_ptr = CSRPtr(new CSR(indptr_arr, indices_arr, edge_ids_arr));
return ImmutableGraph(batched_csr_ptr, nullptr);
return ImmutableGraph::CreateFromCSR(indptr_arr, indices_arr, edge_ids_arr, "in");
}
}
std::vector<ImmutableGraph> GraphOp::DisjointPartitionByNum(const ImmutableGraph *graph,
int64_t num) {
std::vector<GraphPtr> GraphOp::DisjointPartitionByNum(GraphPtr graph, int64_t num) {
CHECK(num != 0 && graph->NumVertices() % num == 0)
<< "Number of partitions must evenly divide the number of nodes.";
IdArray sizes = IdArray::Empty({num}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t *sizes_data = static_cast<int64_t *>(sizes->data);
int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num);
return DisjointPartitionBySizes(graph, sizes);
}
std::vector<ImmutableGraph> GraphOp::DisjointPartitionBySizes(const ImmutableGraph *batched_graph,
IdArray sizes) {
// TODO(minjie): use array views to speedup this operation
std::vector<GraphPtr> GraphOp::DisjointPartitionBySizes(
GraphPtr batched_graph, IdArray sizes) {
const int64_t len = sizes->shape[0];
const int64_t *sizes_data = static_cast<int64_t *>(sizes->data);
const int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
std::vector<int64_t> cumsum;
cumsum.reserve(len + 1);
cumsum.push_back(0);
for (int64_t i = 0; i < len; ++i) {
cumsum.push_back(cumsum[i] + sizes_data[i]);
}
CHECK_EQ(cumsum[len], batched_graph->NumVertices())
<< "Sum of the given sizes must equal to the number of nodes.";
std::vector<ImmutableGraph> rst;
CSRPtr in_csr_ptr = batched_graph->GetInCSR();
const dgl_id_t* indptr = static_cast<dgl_id_t*>(in_csr_ptr->indptr()->data);
const dgl_id_t* indices = static_cast<dgl_id_t*>(in_csr_ptr->indices()->data);
const dgl_id_t* edge_ids = static_cast<dgl_id_t*>(in_csr_ptr->edge_ids()->data);
dgl_id_t cum_sum_edges = 0;
for (int64_t i = 0; i < len; ++i) {
const int64_t start_pos = cumsum[i];
const int64_t end_pos = cumsum[i + 1];
const int64_t g_num_nodes = sizes_data[i];
const int64_t g_num_edges = indptr[end_pos] - indptr[start_pos];
IdArray indptr_arr = aten::NewIdArray(g_num_nodes + 1);
IdArray indices_arr = aten::NewIdArray(g_num_edges);
IdArray edge_ids_arr = aten::NewIdArray(g_num_edges);
dgl_id_t* g_indptr = static_cast<dgl_id_t*>(indptr_arr->data);
dgl_id_t* g_indices = static_cast<dgl_id_t*>(indices_arr->data);
dgl_id_t* g_edge_ids = static_cast<dgl_id_t*>(edge_ids_arr->data);
const dgl_id_t idoff = indptr[start_pos];
g_indptr[0] = 0;
for (int l = start_pos + 1; l < end_pos + 1; ++l) {
g_indptr[l - start_pos] = indptr[l] - indptr[start_pos];
}
for (int j = indptr[start_pos]; j < indptr[end_pos]; ++j) {
g_indices[j - idoff] = indices[j] - cumsum[i];
std::vector<GraphPtr> rst;
if (IsMutable(batched_graph)) {
// Input is a mutable graph. Partition it into several mutable graphs.
MutableGraphPtr graph = std::dynamic_pointer_cast<Graph>(batched_graph);
dgl_id_t node_offset = 0, edge_offset = 0;
for (int64_t i = 0; i < len; ++i) {
MutableGraphPtr mg = Graph::Create();
// TODO(minjie): quite ugly to expose internal members
// copy adj
mg->adjlist_.insert(mg->adjlist_.end(),
graph->adjlist_.begin() + node_offset,
graph->adjlist_.begin() + node_offset + sizes_data[i]);
mg->reverse_adjlist_.insert(mg->reverse_adjlist_.end(),
graph->reverse_adjlist_.begin() + node_offset,
graph->reverse_adjlist_.begin() + node_offset + sizes_data[i]);
// relabel adjs
size_t num_edges = 0;
for (auto& elist : mg->adjlist_) {
for (size_t j = 0; j < elist.succ.size(); ++j) {
elist.succ[j] -= node_offset;
elist.edge_id[j] -= edge_offset;
}
num_edges += elist.succ.size();
}
for (auto& elist : mg->reverse_adjlist_) {
for (size_t j = 0; j < elist.succ.size(); ++j) {
elist.succ[j] -= node_offset;
elist.edge_id[j] -= edge_offset;
}
}
// copy edges
mg->all_edges_src_.reserve(num_edges);
mg->all_edges_dst_.reserve(num_edges);
mg->num_edges_ = num_edges;
for (size_t j = edge_offset; j < edge_offset + num_edges; ++j) {
mg->all_edges_src_.push_back(graph->all_edges_src_[j] - node_offset);
mg->all_edges_dst_.push_back(graph->all_edges_dst_[j] - node_offset);
}
// push to rst
rst.push_back(mg);
// update offset
CHECK_EQ(rst[i]->NumVertices(), sizes_data[i]);
CHECK_EQ(rst[i]->NumEdges(), num_edges);
node_offset += sizes_data[i];
edge_offset += num_edges;
}
} else {
// Input is an immutable graph. Partition it into several multiple graphs.
ImmutableGraphPtr graph = std::dynamic_pointer_cast<ImmutableGraph>(batched_graph);
// TODO(minjie): why in csr?
CSRPtr in_csr_ptr = graph->GetInCSR();
const dgl_id_t* indptr = static_cast<dgl_id_t*>(in_csr_ptr->indptr()->data);
const dgl_id_t* indices = static_cast<dgl_id_t*>(in_csr_ptr->indices()->data);
const dgl_id_t* edge_ids = static_cast<dgl_id_t*>(in_csr_ptr->edge_ids()->data);
dgl_id_t cum_sum_edges = 0;
for (int64_t i = 0; i < len; ++i) {
const int64_t start_pos = cumsum[i];
const int64_t end_pos = cumsum[i + 1];
const int64_t g_num_nodes = sizes_data[i];
const int64_t g_num_edges = indptr[end_pos] - indptr[start_pos];
IdArray indptr_arr = aten::NewIdArray(g_num_nodes + 1);
IdArray indices_arr = aten::NewIdArray(g_num_edges);
IdArray edge_ids_arr = aten::NewIdArray(g_num_edges);
dgl_id_t* g_indptr = static_cast<dgl_id_t*>(indptr_arr->data);
dgl_id_t* g_indices = static_cast<dgl_id_t*>(indices_arr->data);
dgl_id_t* g_edge_ids = static_cast<dgl_id_t*>(edge_ids_arr->data);
for (int k = indptr[start_pos]; k < indptr[end_pos]; ++k) {
g_edge_ids[k - idoff] = edge_ids[k] - cum_sum_edges;
}
const dgl_id_t idoff = indptr[start_pos];
g_indptr[0] = 0;
for (int l = start_pos + 1; l < end_pos + 1; ++l) {
g_indptr[l - start_pos] = indptr[l] - indptr[start_pos];
}
for (int j = indptr[start_pos]; j < indptr[end_pos]; ++j) {
g_indices[j - idoff] = indices[j] - cumsum[i];
}
for (int k = indptr[start_pos]; k < indptr[end_pos]; ++k) {
g_edge_ids[k - idoff] = edge_ids[k] - cum_sum_edges;
}
cum_sum_edges += g_num_edges;
CSRPtr g_in_csr_ptr = CSRPtr(new CSR(indptr_arr, indices_arr, edge_ids_arr));
rst.emplace_back(g_in_csr_ptr, nullptr);
cum_sum_edges += g_num_edges;
rst.push_back(ImmutableGraph::CreateFromCSR(
indptr_arr, indices_arr, edge_ids_arr, "in"));
}
}
return rst;
}
......@@ -297,7 +310,7 @@ IdArray GraphOp::ExpandIds(IdArray ids, IdArray offset) {
return rst;
}
ImmutableGraph GraphOp::ToSimpleGraph(const GraphInterface* graph) {
GraphPtr GraphOp::ToSimpleGraph(GraphPtr graph) {
std::vector<dgl_id_t> indptr(graph->NumVertices() + 1), indices;
indptr[0] = 0;
for (dgl_id_t src = 0; src < graph->NumVertices(); ++src) {
......@@ -312,10 +325,10 @@ ImmutableGraph GraphOp::ToSimpleGraph(const GraphInterface* graph) {
}
CSRPtr csr(new CSR(graph->NumVertices(), indices.size(),
indptr.begin(), indices.begin(), RangeIter(0), false));
return ImmutableGraph(csr);
return std::make_shared<ImmutableGraph>(csr);
}
Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) {
GraphPtr GraphOp::ToBidirectedMutableGraph(GraphPtr g) {
std::unordered_map<int, std::unordered_map<int, int>> n_e;
for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {
for (const dgl_id_t v : g->SuccVec(u)) {
......@@ -323,8 +336,8 @@ Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) {
}
}
Graph bg;
bg.AddVertices(g->NumVertices());
GraphPtr bg = Graph::Create();
bg->AddVertices(g->NumVertices());
for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {
for (dgl_id_t v = u; v < g->NumVertices(); ++v) {
const auto new_n_e = std::max(n_e[u][v], n_e[v][u]);
......@@ -333,13 +346,13 @@ Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) {
dgl_id_t* us_data = static_cast<dgl_id_t*>(us->data);
std::fill(us_data, us_data + new_n_e, u);
if (u == v) {
bg.AddEdges(us, us);
bg->AddEdges(us, us);
} else {
IdArray vs = aten::NewIdArray(new_n_e);
dgl_id_t* vs_data = static_cast<dgl_id_t*>(vs->data);
std::fill(vs_data, vs_data + new_n_e, v);
bg.AddEdges(us, vs);
bg.AddEdges(vs, us);
bg->AddEdges(us, vs);
bg->AddEdges(vs, us);
}
}
}
......@@ -347,7 +360,7 @@ Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) {
return bg;
}
ImmutableGraph GraphOp::ToBidirectedImmutableGraph(const GraphInterface* g) {
GraphPtr GraphOp::ToBidirectedImmutableGraph(GraphPtr g) {
std::unordered_map<int, std::unordered_map<int, int>> n_e;
for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {
for (const dgl_id_t v : g->SuccVec(u)) {
......@@ -382,8 +395,80 @@ ImmutableGraph GraphOp::ToBidirectedImmutableGraph(const GraphInterface* g) {
IdArray srcs_array = aten::VecToIdArray(srcs);
IdArray dsts_array = aten::VecToIdArray(dsts);
COOPtr coo(new COO(g->NumVertices(), srcs_array, dsts_array, g->IsMultigraph()));
return ImmutableGraph(coo);
return ImmutableGraph::CreateFromCOO(
g->NumVertices(), srcs_array, dsts_array, g->IsMultigraph());
}
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
List<GraphRef> graphs = args[0];
std::vector<GraphPtr> ptrs(graphs.size());
for (size_t i = 0; i < graphs.size(); ++i) {
ptrs[i] = graphs[i].sptr();
}
*rv = GraphOp::DisjointUnion(ptrs);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
int64_t num = args[1];
const auto& ret = GraphOp::DisjointPartitionByNum(g.sptr(), num);
List<GraphRef> ret_list;
for (GraphPtr gp : ret) {
ret_list.push_back(GraphRef(gp));
}
*rv = ret_list;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray sizes = args[1];
const auto& ret = GraphOp::DisjointPartitionBySizes(g.sptr(), sizes);
List<GraphRef> ret_list;
for (GraphPtr gp : ret) {
ret_list.push_back(GraphRef(gp));
}
*rv = ret_list;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLineGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
bool backtracking = args[1];
*rv = GraphOp::LineGraph(g.sptr(), backtracking);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLToImmutable")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = ImmutableGraph::ToImmutable(g.sptr());
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = GraphOp::ToSimpleGraph(g.sptr());
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedMutableGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = GraphOp::ToBidirectedMutableGraph(g.sptr());
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedImmutableGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = GraphOp::ToBidirectedImmutableGraph(g.sptr());
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLMapSubgraphNID")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const IdArray parent_vids = args[0];
const IdArray query = args[1];
*rv = GraphOp::MapParentIdToSubgraphId(parent_vids, query);
});
} // namespace dgl
......@@ -4,6 +4,7 @@
* \brief DGL immutable graph index implementation
*/
#include <dgl/packed_func_ext.h>
#include <dgl/immutable_graph.h>
#include <string.h>
#include <bitset>
......@@ -12,8 +13,14 @@
#include "../c_api_common.h"
using namespace dgl::runtime;
namespace dgl {
namespace {
inline std::string GetSharedMemName(const std::string &name, const std::string &edge_dir) {
return name + "_" + edge_dir;
}
std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) {
#ifndef _WIN32
......@@ -124,22 +131,22 @@ bool CSR::IsMultigraph() const {
});
}
CSR::EdgeArray CSR::OutEdges(dgl_id_t vid) const {
EdgeArray CSR::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
return CSR::EdgeArray{ret_src, ret_dst, ret_eid};
return EdgeArray{ret_src, ret_dst, ret_eid};
}
CSR::EdgeArray CSR::OutEdges(IdArray vids) const {
EdgeArray CSR::OutEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
auto csrsubmat = aten::CSRSliceRows(adj_, vids);
auto coosubmat = aten::CSRToCOO(csrsubmat, false);
// Note that the row id in the csr submat is relabled, so
// we need to recover it using an index select.
auto row = aten::IndexSelect(vids, coosubmat.row);
return CSR::EdgeArray{row, coosubmat.col, coosubmat.data};
return EdgeArray{row, coosubmat.col, coosubmat.data};
}
DegreeArray CSR::OutDegrees(IdArray vids) const {
......@@ -171,17 +178,17 @@ IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {
return aten::CSRGetData(adj_, src, dst);
}
CSR::EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
const auto& arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);
return CSR::EdgeArray{arrs[0], arrs[1], arrs[2]};
return EdgeArray{arrs[0], arrs[1], arrs[2]};
}
CSR::EdgeArray CSR::Edges(const std::string &order) const {
EdgeArray CSR::Edges(const std::string &order) const {
CHECK(order.empty() || order == std::string("srcdst"))
<< "CSR only support Edges of order \"srcdst\","
<< " but got \"" << order << "\".";
const auto& coo = aten::CSRToCOO(adj_, false);
return CSR::EdgeArray{coo.row, coo.col, coo.data};
return EdgeArray{coo.row, coo.col, coo.data};
}
Subgraph CSR::VertexSubgraph(IdArray vids) const {
......@@ -289,14 +296,14 @@ std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
return std::pair<dgl_id_t, dgl_id_t>(src, dst);
}
COO::EdgeArray COO::FindEdges(IdArray eids) const {
EdgeArray COO::FindEdges(IdArray eids) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
return EdgeArray{aten::IndexSelect(adj_.row, eids),
aten::IndexSelect(adj_.col, eids),
eids};
}
COO::EdgeArray COO::Edges(const std::string &order) const {
EdgeArray COO::Edges(const std::string &order) const {
CHECK(order.empty() || order == std::string("eid"))
<< "COO only support Edges of order \"eid\", but got \""
<< order << "\".";
......@@ -411,7 +418,7 @@ COOPtr ImmutableGraph::GetCOO() const {
return coo_;
}
ImmutableGraph::EdgeArray ImmutableGraph::Edges(const std::string &order) const {
EdgeArray ImmutableGraph::Edges(const std::string &order) const {
if (order.empty()) {
// arbitrary order
if (in_csr_) {
......@@ -467,53 +474,177 @@ std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &f
}
}
ImmutableGraph ImmutableGraph::ToImmutable(const GraphInterface* graph) {
const ImmutableGraph* ig = dynamic_cast<const ImmutableGraph*>(graph);
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir) {
CSRPtr csr(new CSR(indptr, indices, edge_ids));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir,
const std::string &shared_mem_name) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, GetSharedMemName(shared_mem_name, edge_dir)));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir,
const std::string &shared_mem_name) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph,
GetSharedMemName(shared_mem_name, edge_dir)));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
const std::string &shared_mem_name, size_t num_vertices,
size_t num_edges, bool multigraph,
const std::string &edge_dir) {
CSRPtr csr(new CSR(GetSharedMemName(shared_mem_name, edge_dir), num_vertices, num_edges,
multigraph));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst) {
COOPtr coo(new COO(num_vertices, src, dst));
return std::make_shared<ImmutableGraph>(coo);
}
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst, bool multigraph) {
COOPtr coo(new COO(num_vertices, src, dst, multigraph));
return std::make_shared<ImmutableGraph>(coo);
}
ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {
ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(graph);
if (ig) {
return *ig;
return ig;
} else {
const auto& adj = graph->GetAdj(true, "csr");
CSRPtr csr(new CSR(adj[0], adj[1], adj[2]));
return ImmutableGraph(nullptr, csr);
return ImmutableGraph::CreateFromCSR(adj[0], adj[1], adj[2], "out");
}
}
ImmutableGraph ImmutableGraph::CopyTo(const DLContext& ctx) const {
if (ctx == Context()) {
return *this;
ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DLContext& ctx) {
if (ctx == g->Context()) {
return g;
}
// TODO(minjie): since we don't have GPU implementation of COO<->CSR,
// we make sure that this graph (on CPU) has materialized CSR,
// and then copy them to other context (usually GPU). This should
// be fixed later.
CSRPtr new_incsr = CSRPtr(new CSR(GetInCSR()->CopyTo(ctx)));
CSRPtr new_outcsr = CSRPtr(new CSR(GetOutCSR()->CopyTo(ctx)));
return ImmutableGraph(new_incsr, new_outcsr);
CSRPtr new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyTo(ctx)));
CSRPtr new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyTo(ctx)));
return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr));
}
ImmutableGraph ImmutableGraph::CopyToSharedMem(const std::string &edge_dir,
const std::string &name) const {
ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(ImmutableGraphPtr g,
const std::string &edge_dir, const std::string &name) {
CSRPtr new_incsr, new_outcsr;
std::string shared_mem_name = GetSharedMemName(name, edge_dir);
if (edge_dir == std::string("in"))
new_incsr = CSRPtr(new CSR(GetInCSR()->CopyToSharedMem(shared_mem_name)));
new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyToSharedMem(shared_mem_name)));
else if (edge_dir == std::string("out"))
new_outcsr = CSRPtr(new CSR(GetOutCSR()->CopyToSharedMem(shared_mem_name)));
return ImmutableGraph(new_incsr, new_outcsr, name);
new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyToSharedMem(shared_mem_name)));
return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr, name));
}
ImmutableGraph ImmutableGraph::AsNumBits(uint8_t bits) const {
if (NumBits() == bits) {
return *this;
ImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) {
if (g->NumBits() == bits) {
return g;
} else {
// TODO(minjie): since we don't have int32 operations,
// we make sure that this graph (on CPU) has materialized CSR,
// and then copy them to other context (usually GPU). This should
// be fixed later.
CSRPtr new_incsr = CSRPtr(new CSR(GetInCSR()->AsNumBits(bits)));
CSRPtr new_outcsr = CSRPtr(new CSR(GetOutCSR()->AsNumBits(bits)));
return ImmutableGraph(new_incsr, new_outcsr);
CSRPtr new_incsr = CSRPtr(new CSR(g->GetInCSR()->AsNumBits(bits)));
CSRPtr new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->AsNumBits(bits)));
return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr));
}
}
ImmutableGraphPtr ImmutableGraph::Reverse() const {
if (coo_) {
return ImmutableGraphPtr(new ImmutableGraph(
out_csr_, in_csr_, coo_->Transpose()));
} else {
return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_));
}
}
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const int device_type = args[1];
const int device_id = args[2];
DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::CopyTo(ig, ctx);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
std::string edge_dir = args[1];
std::string name = args[2];
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::CopyToSharedMem(ig, edge_dir, name);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
int bits = args[1];
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::AsNumBits(ig, bits);
});
} // namespace dgl
......@@ -4,6 +4,8 @@
* \brief DGL networking related APIs
*/
#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include "./network.h"
#include "./network/communicator.h"
#include "./network/socket_communicator.h"
......@@ -11,11 +13,7 @@
#include "../c_api_common.h"
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
using namespace dgl::runtime;
namespace dgl {
namespace network {
......@@ -84,12 +82,14 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
int recv_id = args[1];
GraphHandle ghandle = args[2];
// TODO(minjie): could simply use NodeFlow nf = args[2];
GraphRef g = args[2];
const IdArray node_mapping = args[3];
const IdArray edge_mapping = args[4];
const IdArray layer_offsets = args[5];
const IdArray flow_offsets = args[6];
ImmutableGraph *ptr = static_cast<ImmutableGraph*>(ghandle);
auto ptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(ptr) << "only immutable graph is allowed in send/recv";
network::Sender* sender = static_cast<network::Sender*>(chandle);
auto csr = ptr->GetInCSR();
// Write control message
......@@ -159,7 +159,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
RecvData(receiver, buffer, kMaxBufferSize);
int control = *buffer;
if (control == CONTROL_NODEFLOW) {
NodeFlow* nf = new NodeFlow();
NodeFlow nf = NodeFlow::Create();
CSRPtr csr;
// Deserialize nodeflow from recv_data_buffer
network::DeserializeSampledSubgraph(buffer+sizeof(CONTROL_NODEFLOW),
......@@ -169,9 +169,9 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
&(nf->layer_offsets),
&(nf->flow_offsets));
nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr));
std::vector<NodeFlow*> subgs(1);
subgs[0] = nf;
*rv = WrapVectorReturn(subgs);
List<NodeFlow> subgs;
subgs.push_back(nf);
*rv = subgs;
} else if (control == CONTROL_END_SIGNAL) {
*rv = CONTROL_END_SIGNAL;
} else {
......
......@@ -5,9 +5,10 @@
*/
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/nodeflow.h>
#include <string.h>
#include <string>
#include "../c_api_common.h"
......@@ -78,15 +79,14 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetBlockAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
GraphRef g = args[0];
std::string format = args[1];
int64_t layer0_size = args[2];
int64_t start = args[3];
int64_t end = args[4];
const bool remap = args[5];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
const ImmutableGraph* gptr = dynamic_cast<const ImmutableGraph*>(ptr);
auto res = GetNodeFlowSlice(*gptr, format, layer0_size, start, end, remap);
auto ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
auto res = GetNodeFlowSlice(*ig, format, layer0_size, start, end, remap);
*rv = ConvertNDArrayVectorToPackedFunc(res);
});
......
......@@ -7,6 +7,7 @@
#include <dgl/sampler.h>
#include <dmlc/omp.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <algorithm>
#include <cstdlib>
#include <cmath>
......@@ -14,11 +15,7 @@
#include <functional>
#include "../c_api_common.h"
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
using namespace dgl::runtime;
namespace dgl {
......@@ -218,43 +215,40 @@ RandomWalkTraces BipartiteSingleSidedRandomWalkWithRestart(
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
GraphRef g = args[0];
const IdArray seeds = args[1];
const int num_traces = args[2];
const int num_hops = args[3];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
*rv = RandomWalk(ptr, seeds, num_traces, num_hops);
*rv = RandomWalk(g.sptr().get(), seeds, num_traces, num_hops);
});
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
GraphRef g = args[0];
const IdArray seeds = args[1];
const double restart_prob = args[2];
const uint64_t visit_threshold_per_seed = args[3];
const uint64_t max_visit_counts = args[4];
const uint64_t max_frequent_visited_nodes = args[5];
const GraphInterface *gptr = static_cast<const GraphInterface *>(ghandle);
*rv = ConvertRandomWalkTracesToPackedFunc(
RandomWalkWithRestart(gptr, seeds, restart_prob, visit_threshold_per_seed,
RandomWalkWithRestart(g.sptr().get(), seeds, restart_prob, visit_threshold_per_seed,
max_visit_counts, max_frequent_visited_nodes));
});
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLBipartiteSingleSidedRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
GraphRef g = args[0];
const IdArray seeds = args[1];
const double restart_prob = args[2];
const uint64_t visit_threshold_per_seed = args[3];
const uint64_t max_visit_counts = args[4];
const uint64_t max_frequent_visited_nodes = args[5];
const GraphInterface *gptr = static_cast<const GraphInterface *>(ghandle);
*rv = ConvertRandomWalkTracesToPackedFunc(
BipartiteSingleSidedRandomWalkWithRestart(
gptr, seeds, restart_prob, visit_threshold_per_seed,
g.sptr().get(), seeds, restart_prob, visit_threshold_per_seed,
max_visit_counts, max_frequent_visited_nodes));
});
......
......@@ -3,10 +3,10 @@
* \file graph/sampler.cc
* \brief DGL sampler implementation
*/
#include <dgl/sampler.h>
#include <dmlc/omp.h>
#include <dgl/immutable_graph.h>
#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include <dmlc/omp.h>
#include <algorithm>
#include <cstdlib>
......@@ -14,11 +14,7 @@
#include <numeric>
#include "../c_api_common.h"
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
using namespace dgl::runtime;
namespace dgl {
......@@ -246,17 +242,17 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
std::vector<neighbor_info> *neigh_pos,
const std::string &edge_type,
int64_t num_edges, int num_hops, bool is_multigraph) {
NodeFlow nf;
NodeFlow nf = NodeFlow::Create();
uint64_t num_vertices = sub_vers->size();
nf.node_mapping = aten::NewIdArray(num_vertices);
nf.edge_mapping = aten::NewIdArray(num_edges);
nf.layer_offsets = aten::NewIdArray(num_hops + 1);
nf.flow_offsets = aten::NewIdArray(num_hops);
nf->node_mapping = aten::NewIdArray(num_vertices);
nf->edge_mapping = aten::NewIdArray(num_edges);
nf->layer_offsets = aten::NewIdArray(num_hops + 1);
nf->flow_offsets = aten::NewIdArray(num_hops);
dgl_id_t *node_map_data = static_cast<dgl_id_t *>(nf.node_mapping->data);
dgl_id_t *layer_off_data = static_cast<dgl_id_t *>(nf.layer_offsets->data);
dgl_id_t *flow_off_data = static_cast<dgl_id_t *>(nf.flow_offsets->data);
dgl_id_t *edge_map_data = static_cast<dgl_id_t *>(nf.edge_mapping->data);
dgl_id_t *node_map_data = static_cast<dgl_id_t *>(nf->node_mapping->data);
dgl_id_t *layer_off_data = static_cast<dgl_id_t *>(nf->layer_offsets->data);
dgl_id_t *flow_off_data = static_cast<dgl_id_t *>(nf->flow_offsets->data);
dgl_id_t *edge_map_data = static_cast<dgl_id_t *>(nf->edge_mapping->data);
// Construct sub_csr_graph
// TODO(minjie): is nodeflow a multigraph?
......@@ -364,9 +360,9 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
std::iota(eid_out, eid_out + num_edges, 0);
if (edge_type == std::string("in")) {
nf.graph = GraphPtr(new ImmutableGraph(subg_csr, nullptr));
nf->graph = GraphPtr(new ImmutableGraph(subg_csr, nullptr));
} else {
nf.graph = GraphPtr(new ImmutableGraph(nullptr, subg_csr));
nf->graph = GraphPtr(new ImmutableGraph(nullptr, subg_csr));
}
return nf;
......@@ -491,47 +487,34 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
const NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
GraphInterface* gptr = nflow->graph->Reset();
*rv = gptr;
NodeFlow nflow = args[0];
*rv = nflow->graph;
});
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetNodeMapping")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
const NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
NodeFlow nflow = args[0];
*rv = nflow->node_mapping;
});
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetEdgeMapping")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
const NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
NodeFlow nflow = args[0];
*rv = nflow->edge_mapping;
});
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetLayerOffsets")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
const NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
NodeFlow nflow = args[0];
*rv = nflow->layer_offsets;
});
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetBlockOffsets")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
const NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
NodeFlow nflow = args[0];
*rv = nflow->flow_offsets;
});
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowFree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
delete nflow;
});
NodeFlow SamplerOp::NeighborUniformSample(const ImmutableGraph *graph,
const std::vector<dgl_id_t>& seeds,
const std::string &edge_type,
......@@ -702,21 +685,21 @@ NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph,
CHECK_EQ(sub_indptr.back(), sub_indices.size());
CHECK_EQ(sub_indices.size(), sub_edge_ids.size());
NodeFlow nf;
NodeFlow nf = NodeFlow::Create();
auto sub_csr = CSRPtr(new CSR(aten::VecToIdArray(sub_indptr),
aten::VecToIdArray(sub_indices),
aten::VecToIdArray(sub_edge_ids)));
if (neighbor_type == std::string("in")) {
nf.graph = GraphPtr(new ImmutableGraph(sub_csr, nullptr));
nf->graph = GraphPtr(new ImmutableGraph(sub_csr, nullptr));
} else {
nf.graph = GraphPtr(new ImmutableGraph(nullptr, sub_csr));
nf->graph = GraphPtr(new ImmutableGraph(nullptr, sub_csr));
}
nf.node_mapping = aten::VecToIdArray(node_mapping);
nf.edge_mapping = aten::VecToIdArray(edge_mapping);
nf.layer_offsets = aten::VecToIdArray(layer_offsets);
nf.flow_offsets = aten::VecToIdArray(flow_offsets);
nf->node_mapping = aten::VecToIdArray(node_mapping);
nf->edge_mapping = aten::VecToIdArray(edge_mapping);
nf->layer_offsets = aten::VecToIdArray(layer_offsets);
nf->flow_offsets = aten::VecToIdArray(flow_offsets);
return nf;
}
......@@ -736,7 +719,7 @@ void BuildCsr(const ImmutableGraph &g, const std::string neigh_type) {
DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
// arguments
const GraphHandle ghdl = args[0];
GraphRef g = args[0];
const IdArray seed_nodes = args[1];
const int64_t batch_start_id = args[2];
const int64_t batch_size = args[3];
......@@ -746,8 +729,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
const std::string neigh_type = args[7];
const bool add_self_loop = args[8];
// process args
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghdl);
const ImmutableGraph *gptr = dynamic_cast<const ImmutableGraph*>(ptr);
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(IsValidIdArray(seed_nodes));
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
......@@ -757,7 +739,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
// We need to make sure we have the right CSR before we enter parallel sampling.
BuildCsr(*gptr, neigh_type);
// generate node flows
std::vector<NodeFlow*> nflows(num_workers);
std::vector<NodeFlow> nflows(num_workers);
#pragma omp parallel for
for (int i = 0; i < num_workers; i++) {
// create per-worker seed nodes.
......@@ -767,17 +749,16 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
std::vector<dgl_id_t> worker_seeds(end - start);
std::copy(seed_nodes_data + start, seed_nodes_data + end,
worker_seeds.begin());
nflows[i] = new NodeFlow();
*nflows[i] = SamplerOp::NeighborUniformSample(
gptr, worker_seeds, neigh_type, num_hops, expand_factor, add_self_loop);
nflows[i] = SamplerOp::NeighborUniformSample(
gptr.get(), worker_seeds, neigh_type, num_hops, expand_factor, add_self_loop);
}
*rv = WrapVectorReturn(nflows);
*rv = List<NodeFlow>(nflows);
});
DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
// arguments
const GraphHandle ghdl = args[0];
GraphRef g = args[0];
const IdArray seed_nodes = args[1];
const int64_t batch_start_id = args[2];
const int64_t batch_size = args[3];
......@@ -785,8 +766,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
const IdArray layer_sizes = args[5];
const std::string neigh_type = args[6];
// process args
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghdl);
const ImmutableGraph *gptr = dynamic_cast<const ImmutableGraph*>(ptr);
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(IsValidIdArray(seed_nodes));
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
......@@ -796,7 +776,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
// We need to make sure we have the right CSR before we enter parallel sampling.
BuildCsr(*gptr, neigh_type);
// generate node flows
std::vector<NodeFlow*> nflows(num_workers);
std::vector<NodeFlow> nflows(num_workers);
#pragma omp parallel for
for (int i = 0; i < num_workers; i++) {
// create per-worker seed nodes.
......@@ -806,11 +786,10 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
std::vector<dgl_id_t> worker_seeds(end - start);
std::copy(seed_nodes_data + start, seed_nodes_data + end,
worker_seeds.begin());
nflows[i] = new NodeFlow();
*nflows[i] = SamplerOp::LayerUniformSample(
gptr, worker_seeds, neigh_type, layer_sizes);
nflows[i] = SamplerOp::LayerUniformSample(
gptr.get(), worker_seeds, neigh_type, layer_sizes);
}
*rv = WrapVectorReturn(nflows);
*rv = List<NodeFlow>(nflows);
});
} // namespace dgl
......@@ -3,16 +3,13 @@
* \file graph/traversal.cc
* \brief Graph traversal implementation
*/
#include <dgl/packed_func_ext.h>
#include <algorithm>
#include <queue>
#include "./traversal.h"
#include "../c_api_common.h"
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
using namespace dgl::runtime;
namespace dgl {
namespace traverse {
......@@ -115,7 +112,7 @@ struct Frontiers {
std::vector<int64_t> sections;
};
Frontiers BFSNodesFrontiers(const Graph& graph, IdArray source, bool reversed) {
Frontiers BFSNodesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) {
Frontiers front;
VectorQueueWrapper<dgl_id_t> queue(&front.ids);
auto visit = [&] (const dgl_id_t v) { };
......@@ -131,17 +128,16 @@ Frontiers BFSNodesFrontiers(const Graph& graph, IdArray source, bool reversed) {
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
GraphRef g = args[0];
const IdArray src = args[1];
bool reversed = args[2];
const auto& front = BFSNodesFrontiers(*gptr, src, reversed);
const auto& front = BFSNodesFrontiers(*(g.sptr()), src, reversed);
IdArray node_ids = CopyVectorToNDArray(front.ids);
IdArray sections = CopyVectorToNDArray(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
});
Frontiers BFSEdgesFrontiers(const Graph& graph, IdArray source, bool reversed) {
Frontiers BFSEdgesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) {
Frontiers front;
// NOTE: std::queue has no top() method.
std::vector<dgl_id_t> nodes;
......@@ -162,17 +158,16 @@ Frontiers BFSEdgesFrontiers(const Graph& graph, IdArray source, bool reversed) {
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
GraphRef g = args[0];
const IdArray src = args[1];
bool reversed = args[2];
const auto& front = BFSEdgesFrontiers(*gptr, src, reversed);
const auto& front = BFSEdgesFrontiers(*(g.sptr()), src, reversed);
IdArray edge_ids = CopyVectorToNDArray(front.ids);
IdArray sections = CopyVectorToNDArray(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections});
});
Frontiers TopologicalNodesFrontiers(const Graph& graph, bool reversed) {
Frontiers TopologicalNodesFrontiers(const GraphInterface& graph, bool reversed) {
Frontiers front;
VectorQueueWrapper<dgl_id_t> queue(&front.ids);
auto visit = [&] (const dgl_id_t v) { };
......@@ -188,10 +183,9 @@ Frontiers TopologicalNodesFrontiers(const Graph& graph, bool reversed) {
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
GraphRef g = args[0];
bool reversed = args[1];
const auto& front = TopologicalNodesFrontiers(*gptr, reversed);
const auto& front = TopologicalNodesFrontiers(*g.sptr(), reversed);
IdArray node_ids = CopyVectorToNDArray(front.ids);
IdArray sections = CopyVectorToNDArray(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
......@@ -200,8 +194,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
GraphRef g = args[0];
const IdArray source = args[1];
const bool reversed = args[2];
CHECK(IsValidIdArray(source)) << "Invalid source node id array.";
......@@ -210,7 +203,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
std::vector<std::vector<dgl_id_t>> edges(len);
for (int64_t i = 0; i < len; ++i) {
auto visit = [&] (dgl_id_t e, int tag) { edges[i].push_back(e); };
DFSLabeledEdges(*gptr, src_data[i], reversed, false, false, visit);
DFSLabeledEdges(*g.sptr(), src_data[i], reversed, false, false, visit);
}
IdArray ids = MergeMultipleTraversals(edges);
IdArray sections = ComputeMergedSections(edges);
......@@ -219,8 +212,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
GraphRef g = args[0];
const IdArray source = args[1];
const bool reversed = args[2];
const bool has_reverse_edge = args[3];
......@@ -243,7 +235,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
tags[i].push_back(tag);
}
};
DFSLabeledEdges(*gptr, src_data[i], reversed,
DFSLabeledEdges(*g.sptr(), src_data[i], reversed,
has_reverse_edge, has_nontree_edge, visit);
}
......
......@@ -11,7 +11,7 @@
#ifndef DGL_GRAPH_TRAVERSAL_H_
#define DGL_GRAPH_TRAVERSAL_H_
#include <dgl/graph.h>
#include <dgl/graph_interface.h>
#include <stack>
#include <tuple>
#include <vector>
......@@ -45,7 +45,7 @@ namespace traverse {
* \param make_frontier The function to indicate that a new froniter can be made;
*/
template<typename Queue, typename VisitFn, typename FrontierFn>
void BFSNodes(const Graph& graph,
void BFSNodes(const GraphInterface& graph,
IdArray source,
bool reversed,
Queue* queue,
......@@ -63,7 +63,7 @@ void BFSNodes(const Graph& graph,
}
make_frontier();
const auto neighbor_iter = reversed? &Graph::PredVec : &Graph::SuccVec;
const auto neighbor_iter = reversed? &GraphInterface::PredVec : &GraphInterface::SuccVec;
while (!queue->empty()) {
const size_t size = queue->size();
for (size_t i = 0; i < size; ++i) {
......@@ -109,7 +109,7 @@ void BFSNodes(const Graph& graph,
* \param make_frontier The function to indicate that a new frontier can be made;
*/
template<typename Queue, typename VisitFn, typename FrontierFn>
void BFSEdges(const Graph& graph,
void BFSEdges(const GraphInterface& graph,
IdArray source,
bool reversed,
Queue* queue,
......@@ -126,7 +126,7 @@ void BFSEdges(const Graph& graph,
}
make_frontier();
const auto neighbor_iter = reversed? &Graph::InEdgeVec : &Graph::OutEdgeVec;
const auto neighbor_iter = reversed? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;
while (!queue->empty()) {
const size_t size = queue->size();
for (size_t i = 0; i < size; ++i) {
......@@ -171,13 +171,13 @@ void BFSEdges(const Graph& graph,
* \param make_frontier The function to indicate that a new froniter can be made;
*/
template<typename Queue, typename VisitFn, typename FrontierFn>
void TopologicalNodes(const Graph& graph,
void TopologicalNodes(const GraphInterface& graph,
bool reversed,
Queue* queue,
VisitFn visit,
FrontierFn make_frontier) {
const auto get_degree = reversed? &Graph::OutDegree : &Graph::InDegree;
const auto neighbor_iter = reversed? &Graph::PredVec : &Graph::SuccVec;
const auto get_degree = reversed? &GraphInterface::OutDegree : &GraphInterface::InDegree;
const auto neighbor_iter = reversed? &GraphInterface::PredVec : &GraphInterface::SuccVec;
uint64_t num_visited_nodes = 0;
std::vector<uint64_t> degrees(graph.NumVertices(), 0);
for (dgl_id_t vid = 0; vid < graph.NumVertices(); ++vid) {
......@@ -237,14 +237,14 @@ enum DFSEdgeTag {
* tag will be given as the arguments.
*/
template<typename VisitFn>
void DFSLabeledEdges(const Graph& graph,
void DFSLabeledEdges(const GraphInterface& graph,
dgl_id_t source,
bool reversed,
bool has_reverse_edge,
bool has_nontree_edge,
VisitFn visit) {
const auto succ = reversed? &Graph::PredVec : &Graph::SuccVec;
const auto out_edge = reversed? &Graph::InEdgeVec : &Graph::OutEdgeVec;
const auto succ = reversed? &GraphInterface::PredVec : &GraphInterface::SuccVec;
const auto out_edge = reversed? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;
if ((graph.*succ)(source).size() == 0) {
// no out-going edges from the source node
......
......@@ -3,17 +3,14 @@
* \file kernel/binary_reduce.cc
* \brief Binary reduce C APIs and definitions.
*/
#include <dgl/packed_func_ext.h>
#include "./binary_reduce.h"
#include "./common.h"
#include "./binary_reduce_impl_decl.h"
#include "./utils.h"
#include "../c_api_common.h"
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
using namespace dgl::runtime;
namespace dgl {
namespace kernel {
......@@ -273,7 +270,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0];
std::string op = args[1];
GraphHandle ghdl = args[2];
GraphRef g = args[2];
int lhs = args[3];
int rhs = args[4];
NDArray lhs_data = args[5];
......@@ -283,10 +280,9 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
NDArray rhs_mapping = args[9];
NDArray out_mapping = args[10];
GraphInterface* gptr = static_cast<GraphInterface*>(ghdl);
const ImmutableGraph* igptr = dynamic_cast<ImmutableGraph*>(gptr);
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
BinaryOpReduce(reducer, op, igptr,
BinaryOpReduce(reducer, op, igptr.get(),
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_data, rhs_data, out_data,
lhs_mapping, rhs_mapping, out_mapping);
......@@ -346,7 +342,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0];
std::string op = args[1];
GraphHandle ghdl = args[2];
GraphRef g = args[2];
int lhs = args[3];
int rhs = args[4];
NDArray lhs_mapping = args[5];
......@@ -358,11 +354,10 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
NDArray grad_out_data = args[11];
NDArray grad_lhs_data = args[12];
GraphInterface* gptr = static_cast<GraphInterface*>(ghdl);
const ImmutableGraph* igptr = dynamic_cast<ImmutableGraph*>(gptr);
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
BackwardLhsBinaryOpReduce(
reducer, op, igptr,
reducer, op, igptr.get(),
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_mapping, rhs_mapping, out_mapping,
lhs_data, rhs_data, out_data, grad_out_data,
......@@ -422,7 +417,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0];
std::string op = args[1];
GraphHandle ghdl = args[2];
GraphRef g = args[2];
int lhs = args[3];
int rhs = args[4];
NDArray lhs_mapping = args[5];
......@@ -434,11 +429,10 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
NDArray grad_out_data = args[11];
NDArray grad_rhs_data = args[12];
GraphInterface* gptr = static_cast<GraphInterface*>(ghdl);
const ImmutableGraph* igptr = dynamic_cast<ImmutableGraph*>(gptr);
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
BackwardRhsBinaryOpReduce(
reducer, op, igptr,
reducer, op, igptr.get(),
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_mapping, rhs_mapping, out_mapping,
lhs_data, rhs_data, out_data, grad_out_data,
......@@ -469,17 +463,16 @@ void CopyReduce(
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0];
GraphHandle ghdl = args[1];
GraphRef g = args[1];
int target = args[2];
NDArray in_data = args[3];
NDArray out_data = args[4];
NDArray in_mapping = args[5];
NDArray out_mapping = args[6];
GraphInterface* gptr = static_cast<GraphInterface*>(ghdl);
const ImmutableGraph* igptr = dynamic_cast<ImmutableGraph*>(gptr);
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
CopyReduce(reducer, igptr,
CopyReduce(reducer, igptr.get(),
static_cast<binary_op::Target>(target),
in_data, out_data,
in_mapping, out_mapping);
......@@ -518,7 +511,7 @@ void BackwardCopyReduce(
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0];
GraphHandle ghdl = args[1];
GraphRef g = args[1];
int target = args[2];
NDArray in_data = args[3];
NDArray out_data = args[4];
......@@ -527,11 +520,10 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce")
NDArray in_mapping = args[7];
NDArray out_mapping = args[8];
GraphInterface* gptr = static_cast<GraphInterface*>(ghdl);
const ImmutableGraph* igptr = dynamic_cast<ImmutableGraph*>(gptr);
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
BackwardCopyReduce(
reducer, igptr, static_cast<binary_op::Target>(target),
reducer, igptr.get(), static_cast<binary_op::Target>(target),
in_mapping, out_mapping,
in_data, out_data, grad_out_data,
grad_in_data);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment