Unverified Commit 67dc1197 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Refactor] Use object system for all CAPIs (#716)

* WIP: using object system for graph

* c++ side refactoring done; compiled

* remove stale apis

* fix bug in DGLGraphCreate; passed test_graph.py

* fix bug in python modify; passed utest for pytorch/cpu

* fix lint

* address comments
parent b0d9e7aa
...@@ -25,25 +25,4 @@ PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) { ...@@ -25,25 +25,4 @@ PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) {
return PackedFunc(body); return PackedFunc(body);
} }
DGL_REGISTER_GLOBAL("_GetVectorWrapperSize")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
const CAPIVectorWrapper* wrapper = static_cast<const CAPIVectorWrapper*>(ptr);
*rv = static_cast<int64_t>(wrapper->pointers.size());
});
DGL_REGISTER_GLOBAL("_GetVectorWrapperData")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
CAPIVectorWrapper* wrapper = static_cast<CAPIVectorWrapper*>(ptr);
*rv = static_cast<void*>(wrapper->pointers.data());
});
DGL_REGISTER_GLOBAL("_FreeVectorWrapper")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
CAPIVectorWrapper* wrapper = static_cast<CAPIVectorWrapper*>(ptr);
delete wrapper;
});
} // namespace dgl } // namespace dgl
...@@ -34,9 +34,6 @@ inline std::ostream& operator << (std::ostream& os, const DLContext& ctx) { ...@@ -34,9 +34,6 @@ inline std::ostream& operator << (std::ostream& os, const DLContext& ctx) {
namespace dgl { namespace dgl {
// Graph handler type
typedef void* GraphHandle;
// Communicator handler type // Communicator handler type
typedef void* CommunicatorHandle; typedef void* CommunicatorHandle;
...@@ -73,29 +70,6 @@ dgl::runtime::NDArray CopyVectorToNDArray( ...@@ -73,29 +70,6 @@ dgl::runtime::NDArray CopyVectorToNDArray(
return a; return a;
} }
/* A structure used to return a vector of void* pointers. */
struct CAPIVectorWrapper {
// The pointer vector.
std::vector<void*> pointers;
};
/*!
* \brief A helper function used to return vector of pointers from C to frontend.
*
* Note that the function will move the given vector memory into the returned
* wrapper object.
*
* \param vec The given pointer vectors.
* \return A wrapper object containing the given data.
*/
template<typename PType>
CAPIVectorWrapper* WrapVectorReturn(std::vector<PType*> vec) {
CAPIVectorWrapper* wrapper = new CAPIVectorWrapper;
wrapper->pointers.reserve(vec.size());
wrapper->pointers.insert(wrapper->pointers.end(), vec.begin(), vec.end());
return wrapper;
}
} // namespace dgl } // namespace dgl
#endif // DGL_C_API_COMMON_H_ #endif // DGL_C_API_COMMON_H_
...@@ -200,7 +200,7 @@ IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const { ...@@ -200,7 +200,7 @@ IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
} }
// O(E*k) pretty slow // O(E*k) pretty slow
Graph::EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const { EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array."; CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array."; CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0]; const auto srclen = src_ids->shape[0];
...@@ -246,7 +246,7 @@ Graph::EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const { ...@@ -246,7 +246,7 @@ Graph::EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
return EdgeArray{rst_src, rst_dst, rst_eid}; return EdgeArray{rst_src, rst_dst, rst_eid};
} }
Graph::EdgeArray Graph::FindEdges(IdArray eids) const { EdgeArray Graph::FindEdges(IdArray eids) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array"; CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
int64_t len = eids->shape[0]; int64_t len = eids->shape[0];
...@@ -272,7 +272,7 @@ Graph::EdgeArray Graph::FindEdges(IdArray eids) const { ...@@ -272,7 +272,7 @@ Graph::EdgeArray Graph::FindEdges(IdArray eids) const {
} }
// O(E) // O(E)
Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const { EdgeArray Graph::InEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = reverse_adjlist_[vid].succ.size(); const int64_t len = reverse_adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
...@@ -290,7 +290,7 @@ Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const { ...@@ -290,7 +290,7 @@ Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const {
} }
// O(E) // O(E)
Graph::EdgeArray Graph::InEdges(IdArray vids) const { EdgeArray Graph::InEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0]; const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data); const int64_t* vid_data = static_cast<int64_t*>(vids->data);
...@@ -318,7 +318,7 @@ Graph::EdgeArray Graph::InEdges(IdArray vids) const { ...@@ -318,7 +318,7 @@ Graph::EdgeArray Graph::InEdges(IdArray vids) const {
} }
// O(E) // O(E)
Graph::EdgeArray Graph::OutEdges(dgl_id_t vid) const { EdgeArray Graph::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = adjlist_[vid].succ.size(); const int64_t len = adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
...@@ -336,7 +336,7 @@ Graph::EdgeArray Graph::OutEdges(dgl_id_t vid) const { ...@@ -336,7 +336,7 @@ Graph::EdgeArray Graph::OutEdges(dgl_id_t vid) const {
} }
// O(E) // O(E)
Graph::EdgeArray Graph::OutEdges(IdArray vids) const { EdgeArray Graph::OutEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0]; const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data); const int64_t* vid_data = static_cast<int64_t*>(vids->data);
...@@ -364,7 +364,7 @@ Graph::EdgeArray Graph::OutEdges(IdArray vids) const { ...@@ -364,7 +364,7 @@ Graph::EdgeArray Graph::OutEdges(IdArray vids) const {
} }
// O(E*log(E)) if sort is required; otherwise, O(E) // O(E*log(E)) if sort is required; otherwise, O(E)
Graph::EdgeArray Graph::Edges(const std::string &order) const { EdgeArray Graph::Edges(const std::string &order) const {
const int64_t len = num_edges_; const int64_t len = num_edges_;
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
...@@ -585,9 +585,4 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const ...@@ -585,9 +585,4 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const
} }
} }
GraphPtr Graph::Reverse() const {
LOG(FATAL) << "not implemented";
return nullptr;
}
} // namespace dgl } // namespace dgl
This diff is collapsed.
...@@ -5,9 +5,13 @@ ...@@ -5,9 +5,13 @@
*/ */
#include <dgl/graph_op.h> #include <dgl/graph_op.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <algorithm> #include <algorithm>
#include "../c_api_common.h" #include "../c_api_common.h"
using namespace dgl::runtime;
namespace dgl { namespace dgl {
namespace { namespace {
// generate consecutive dgl ids // generate consecutive dgl ids
...@@ -38,103 +42,61 @@ class RangeIter : public std::iterator<std::input_iterator_tag, dgl_id_t> { ...@@ -38,103 +42,61 @@ class RangeIter : public std::iterator<std::input_iterator_tag, dgl_id_t> {
private: private:
dgl_id_t cur_; dgl_id_t cur_;
}; };
bool IsMutable(GraphPtr g) {
MutableGraphPtr mg = std::dynamic_pointer_cast<Graph>(g);
return mg != nullptr;
}
} // namespace } // namespace
Graph GraphOp::LineGraph(const Graph* g, bool backtracking) { GraphPtr GraphOp::Reverse(GraphPtr g) {
Graph lg; ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(g);
lg.AddVertices(g->NumEdges()); CHECK(ig) << "Reverse is only supported on immutable graph";
for (size_t i = 0; i < g->all_edges_src_.size(); ++i) { return ig->Reverse();
const auto u = g->all_edges_src_[i]; }
const auto v = g->all_edges_dst_[i];
for (size_t j = 0; j < g->adjlist_[v].succ.size(); ++j) { GraphPtr GraphOp::LineGraph(GraphPtr g, bool backtracking) {
if (backtracking || (!backtracking && g->adjlist_[v].succ[j] != u)) { MutableGraphPtr mg = std::dynamic_pointer_cast<Graph>(g);
lg.AddEdge(i, g->adjlist_[v].edge_id[j]); CHECK(mg) << "Line graph transformation is only supported on mutable graph";
MutableGraphPtr lg = Graph::Create();
lg->AddVertices(g->NumEdges());
for (size_t i = 0; i < mg->all_edges_src_.size(); ++i) {
const auto u = mg->all_edges_src_[i];
const auto v = mg->all_edges_dst_[i];
for (size_t j = 0; j < mg->adjlist_[v].succ.size(); ++j) {
if (backtracking || (!backtracking && mg->adjlist_[v].succ[j] != u)) {
lg->AddEdge(i, mg->adjlist_[v].edge_id[j]);
} }
} }
} }
return lg; return lg;
} }
Graph GraphOp::DisjointUnion(std::vector<const Graph*> graphs) { GraphPtr GraphOp::DisjointUnion(std::vector<GraphPtr> graphs) {
Graph rst; CHECK_GT(graphs.size(), 0) << "Input graph list is empty";
if (IsMutable(graphs[0])) {
// Disjointly union of a list of mutable graph inputs. The result is
// also a mutable graph.
MutableGraphPtr rst = Graph::Create();
uint64_t cumsum = 0; uint64_t cumsum = 0;
for (const Graph* gr : graphs) { for (GraphPtr gr : graphs) {
rst.AddVertices(gr->NumVertices()); MutableGraphPtr mg = std::dynamic_pointer_cast<Graph>(gr);
CHECK(mg) << "All the input graphs should be mutable graphs.";
rst->AddVertices(gr->NumVertices());
for (uint64_t i = 0; i < gr->NumEdges(); ++i) { for (uint64_t i = 0; i < gr->NumEdges(); ++i) {
rst.AddEdge(gr->all_edges_src_[i] + cumsum, gr->all_edges_dst_[i] + cumsum); // TODO(minjie): quite ugly to expose internal members
rst->AddEdge(mg->all_edges_src_[i] + cumsum, mg->all_edges_dst_[i] + cumsum);
} }
cumsum += gr->NumVertices(); cumsum += gr->NumVertices();
} }
return rst; return rst;
} } else {
// Disjointly union of a list of immutable graph inputs. The result is
std::vector<Graph> GraphOp::DisjointPartitionByNum(const Graph* graph, int64_t num) { // also an immutable graph.
CHECK(num != 0 && graph->NumVertices() % num == 0)
<< "Number of partitions must evenly divide the number of nodes.";
IdArray sizes = IdArray::Empty({num}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num);
return DisjointPartitionBySizes(graph, sizes);
}
std::vector<Graph> GraphOp::DisjointPartitionBySizes(const Graph* graph, IdArray sizes) {
const int64_t len = sizes->shape[0];
const int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
std::vector<int64_t> cumsum;
cumsum.push_back(0);
for (int64_t i = 0; i < len; ++i) {
cumsum.push_back(cumsum[i] + sizes_data[i]);
}
CHECK_EQ(cumsum[len], graph->NumVertices())
<< "Sum of the given sizes must equal to the number of nodes.";
dgl_id_t node_offset = 0, edge_offset = 0;
std::vector<Graph> rst(len);
for (int64_t i = 0; i < len; ++i) {
// copy adj
rst[i].adjlist_.insert(rst[i].adjlist_.end(),
graph->adjlist_.begin() + node_offset,
graph->adjlist_.begin() + node_offset + sizes_data[i]);
rst[i].reverse_adjlist_.insert(rst[i].reverse_adjlist_.end(),
graph->reverse_adjlist_.begin() + node_offset,
graph->reverse_adjlist_.begin() + node_offset + sizes_data[i]);
// relabel adjs
size_t num_edges = 0;
for (auto& elist : rst[i].adjlist_) {
for (size_t j = 0; j < elist.succ.size(); ++j) {
elist.succ[j] -= node_offset;
elist.edge_id[j] -= edge_offset;
}
num_edges += elist.succ.size();
}
for (auto& elist : rst[i].reverse_adjlist_) {
for (size_t j = 0; j < elist.succ.size(); ++j) {
elist.succ[j] -= node_offset;
elist.edge_id[j] -= edge_offset;
}
}
// copy edges
rst[i].all_edges_src_.reserve(num_edges);
rst[i].all_edges_dst_.reserve(num_edges);
rst[i].num_edges_ = num_edges;
for (size_t j = edge_offset; j < edge_offset + num_edges; ++j) {
rst[i].all_edges_src_.push_back(graph->all_edges_src_[j] - node_offset);
rst[i].all_edges_dst_.push_back(graph->all_edges_dst_[j] - node_offset);
}
// update offset
CHECK_EQ(rst[i].NumVertices(), sizes_data[i]);
CHECK_EQ(rst[i].NumEdges(), num_edges);
node_offset += sizes_data[i];
edge_offset += num_edges;
}
return rst;
}
ImmutableGraph GraphOp::DisjointUnion(std::vector<const ImmutableGraph *> graphs) {
int64_t num_nodes = 0; int64_t num_nodes = 0;
int64_t num_edges = 0; int64_t num_edges = 0;
for (const ImmutableGraph *gr : graphs) { for (auto gr : graphs) {
num_nodes += gr->NumVertices(); num_nodes += gr->NumVertices();
num_edges += gr->NumEdges(); num_edges += gr->NumEdges();
} }
...@@ -148,7 +110,10 @@ ImmutableGraph GraphOp::DisjointUnion(std::vector<const ImmutableGraph *> graphs ...@@ -148,7 +110,10 @@ ImmutableGraph GraphOp::DisjointUnion(std::vector<const ImmutableGraph *> graphs
indptr[0] = 0; indptr[0] = 0;
dgl_id_t cum_num_nodes = 0; dgl_id_t cum_num_nodes = 0;
dgl_id_t cum_num_edges = 0; dgl_id_t cum_num_edges = 0;
for (const ImmutableGraph *gr : graphs) { for (auto g : graphs) {
ImmutableGraphPtr gr = std::dynamic_pointer_cast<ImmutableGraph>(g);
CHECK(gr) << "All the input graphs should be immutable graphs.";
// TODO(minjie): why in csr?
const CSRPtr g_csrptr = gr->GetInCSR(); const CSRPtr g_csrptr = gr->GetInCSR();
const int64_t g_num_nodes = g_csrptr->NumVertices(); const int64_t g_num_nodes = g_csrptr->NumVertices();
const int64_t g_num_edges = g_csrptr->NumEdges(); const int64_t g_num_edges = g_csrptr->NumEdges();
...@@ -169,35 +134,82 @@ ImmutableGraph GraphOp::DisjointUnion(std::vector<const ImmutableGraph *> graphs ...@@ -169,35 +134,82 @@ ImmutableGraph GraphOp::DisjointUnion(std::vector<const ImmutableGraph *> graphs
cum_num_edges += g_num_edges; cum_num_edges += g_num_edges;
} }
CSRPtr batched_csr_ptr = CSRPtr(new CSR(indptr_arr, indices_arr, edge_ids_arr)); return ImmutableGraph::CreateFromCSR(indptr_arr, indices_arr, edge_ids_arr, "in");
return ImmutableGraph(batched_csr_ptr, nullptr); }
} }
std::vector<ImmutableGraph> GraphOp::DisjointPartitionByNum(const ImmutableGraph *graph, std::vector<GraphPtr> GraphOp::DisjointPartitionByNum(GraphPtr graph, int64_t num) {
int64_t num) {
CHECK(num != 0 && graph->NumVertices() % num == 0) CHECK(num != 0 && graph->NumVertices() % num == 0)
<< "Number of partitions must evenly divide the number of nodes."; << "Number of partitions must evenly divide the number of nodes.";
IdArray sizes = IdArray::Empty({num}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray sizes = IdArray::Empty({num}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t *sizes_data = static_cast<int64_t *>(sizes->data); int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num); std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num);
return DisjointPartitionBySizes(graph, sizes); return DisjointPartitionBySizes(graph, sizes);
} }
std::vector<ImmutableGraph> GraphOp::DisjointPartitionBySizes(const ImmutableGraph *batched_graph, std::vector<GraphPtr> GraphOp::DisjointPartitionBySizes(
IdArray sizes) { GraphPtr batched_graph, IdArray sizes) {
// TODO(minjie): use array views to speedup this operation
const int64_t len = sizes->shape[0]; const int64_t len = sizes->shape[0];
const int64_t *sizes_data = static_cast<int64_t *>(sizes->data); const int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
std::vector<int64_t> cumsum; std::vector<int64_t> cumsum;
cumsum.reserve(len + 1);
cumsum.push_back(0); cumsum.push_back(0);
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
cumsum.push_back(cumsum[i] + sizes_data[i]); cumsum.push_back(cumsum[i] + sizes_data[i]);
} }
CHECK_EQ(cumsum[len], batched_graph->NumVertices()) CHECK_EQ(cumsum[len], batched_graph->NumVertices())
<< "Sum of the given sizes must equal to the number of nodes."; << "Sum of the given sizes must equal to the number of nodes.";
std::vector<ImmutableGraph> rst;
CSRPtr in_csr_ptr = batched_graph->GetInCSR(); std::vector<GraphPtr> rst;
if (IsMutable(batched_graph)) {
// Input is a mutable graph. Partition it into several mutable graphs.
MutableGraphPtr graph = std::dynamic_pointer_cast<Graph>(batched_graph);
dgl_id_t node_offset = 0, edge_offset = 0;
for (int64_t i = 0; i < len; ++i) {
MutableGraphPtr mg = Graph::Create();
// TODO(minjie): quite ugly to expose internal members
// copy adj
mg->adjlist_.insert(mg->adjlist_.end(),
graph->adjlist_.begin() + node_offset,
graph->adjlist_.begin() + node_offset + sizes_data[i]);
mg->reverse_adjlist_.insert(mg->reverse_adjlist_.end(),
graph->reverse_adjlist_.begin() + node_offset,
graph->reverse_adjlist_.begin() + node_offset + sizes_data[i]);
// relabel adjs
size_t num_edges = 0;
for (auto& elist : mg->adjlist_) {
for (size_t j = 0; j < elist.succ.size(); ++j) {
elist.succ[j] -= node_offset;
elist.edge_id[j] -= edge_offset;
}
num_edges += elist.succ.size();
}
for (auto& elist : mg->reverse_adjlist_) {
for (size_t j = 0; j < elist.succ.size(); ++j) {
elist.succ[j] -= node_offset;
elist.edge_id[j] -= edge_offset;
}
}
// copy edges
mg->all_edges_src_.reserve(num_edges);
mg->all_edges_dst_.reserve(num_edges);
mg->num_edges_ = num_edges;
for (size_t j = edge_offset; j < edge_offset + num_edges; ++j) {
mg->all_edges_src_.push_back(graph->all_edges_src_[j] - node_offset);
mg->all_edges_dst_.push_back(graph->all_edges_dst_[j] - node_offset);
}
// push to rst
rst.push_back(mg);
// update offset
CHECK_EQ(rst[i]->NumVertices(), sizes_data[i]);
CHECK_EQ(rst[i]->NumEdges(), num_edges);
node_offset += sizes_data[i];
edge_offset += num_edges;
}
} else {
// Input is an immutable graph. Partition it into several multiple graphs.
ImmutableGraphPtr graph = std::dynamic_pointer_cast<ImmutableGraph>(batched_graph);
// TODO(minjie): why in csr?
CSRPtr in_csr_ptr = graph->GetInCSR();
const dgl_id_t* indptr = static_cast<dgl_id_t*>(in_csr_ptr->indptr()->data); const dgl_id_t* indptr = static_cast<dgl_id_t*>(in_csr_ptr->indptr()->data);
const dgl_id_t* indices = static_cast<dgl_id_t*>(in_csr_ptr->indices()->data); const dgl_id_t* indices = static_cast<dgl_id_t*>(in_csr_ptr->indices()->data);
const dgl_id_t* edge_ids = static_cast<dgl_id_t*>(in_csr_ptr->edge_ids()->data); const dgl_id_t* edge_ids = static_cast<dgl_id_t*>(in_csr_ptr->edge_ids()->data);
...@@ -229,8 +241,9 @@ std::vector<ImmutableGraph> GraphOp::DisjointPartitionBySizes(const ImmutableGra ...@@ -229,8 +241,9 @@ std::vector<ImmutableGraph> GraphOp::DisjointPartitionBySizes(const ImmutableGra
} }
cum_sum_edges += g_num_edges; cum_sum_edges += g_num_edges;
CSRPtr g_in_csr_ptr = CSRPtr(new CSR(indptr_arr, indices_arr, edge_ids_arr)); rst.push_back(ImmutableGraph::CreateFromCSR(
rst.emplace_back(g_in_csr_ptr, nullptr); indptr_arr, indices_arr, edge_ids_arr, "in"));
}
} }
return rst; return rst;
} }
...@@ -297,7 +310,7 @@ IdArray GraphOp::ExpandIds(IdArray ids, IdArray offset) { ...@@ -297,7 +310,7 @@ IdArray GraphOp::ExpandIds(IdArray ids, IdArray offset) {
return rst; return rst;
} }
ImmutableGraph GraphOp::ToSimpleGraph(const GraphInterface* graph) { GraphPtr GraphOp::ToSimpleGraph(GraphPtr graph) {
std::vector<dgl_id_t> indptr(graph->NumVertices() + 1), indices; std::vector<dgl_id_t> indptr(graph->NumVertices() + 1), indices;
indptr[0] = 0; indptr[0] = 0;
for (dgl_id_t src = 0; src < graph->NumVertices(); ++src) { for (dgl_id_t src = 0; src < graph->NumVertices(); ++src) {
...@@ -312,10 +325,10 @@ ImmutableGraph GraphOp::ToSimpleGraph(const GraphInterface* graph) { ...@@ -312,10 +325,10 @@ ImmutableGraph GraphOp::ToSimpleGraph(const GraphInterface* graph) {
} }
CSRPtr csr(new CSR(graph->NumVertices(), indices.size(), CSRPtr csr(new CSR(graph->NumVertices(), indices.size(),
indptr.begin(), indices.begin(), RangeIter(0), false)); indptr.begin(), indices.begin(), RangeIter(0), false));
return ImmutableGraph(csr); return std::make_shared<ImmutableGraph>(csr);
} }
Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) { GraphPtr GraphOp::ToBidirectedMutableGraph(GraphPtr g) {
std::unordered_map<int, std::unordered_map<int, int>> n_e; std::unordered_map<int, std::unordered_map<int, int>> n_e;
for (dgl_id_t u = 0; u < g->NumVertices(); ++u) { for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {
for (const dgl_id_t v : g->SuccVec(u)) { for (const dgl_id_t v : g->SuccVec(u)) {
...@@ -323,8 +336,8 @@ Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) { ...@@ -323,8 +336,8 @@ Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) {
} }
} }
Graph bg; GraphPtr bg = Graph::Create();
bg.AddVertices(g->NumVertices()); bg->AddVertices(g->NumVertices());
for (dgl_id_t u = 0; u < g->NumVertices(); ++u) { for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {
for (dgl_id_t v = u; v < g->NumVertices(); ++v) { for (dgl_id_t v = u; v < g->NumVertices(); ++v) {
const auto new_n_e = std::max(n_e[u][v], n_e[v][u]); const auto new_n_e = std::max(n_e[u][v], n_e[v][u]);
...@@ -333,13 +346,13 @@ Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) { ...@@ -333,13 +346,13 @@ Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) {
dgl_id_t* us_data = static_cast<dgl_id_t*>(us->data); dgl_id_t* us_data = static_cast<dgl_id_t*>(us->data);
std::fill(us_data, us_data + new_n_e, u); std::fill(us_data, us_data + new_n_e, u);
if (u == v) { if (u == v) {
bg.AddEdges(us, us); bg->AddEdges(us, us);
} else { } else {
IdArray vs = aten::NewIdArray(new_n_e); IdArray vs = aten::NewIdArray(new_n_e);
dgl_id_t* vs_data = static_cast<dgl_id_t*>(vs->data); dgl_id_t* vs_data = static_cast<dgl_id_t*>(vs->data);
std::fill(vs_data, vs_data + new_n_e, v); std::fill(vs_data, vs_data + new_n_e, v);
bg.AddEdges(us, vs); bg->AddEdges(us, vs);
bg.AddEdges(vs, us); bg->AddEdges(vs, us);
} }
} }
} }
...@@ -347,7 +360,7 @@ Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) { ...@@ -347,7 +360,7 @@ Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) {
return bg; return bg;
} }
ImmutableGraph GraphOp::ToBidirectedImmutableGraph(const GraphInterface* g) { GraphPtr GraphOp::ToBidirectedImmutableGraph(GraphPtr g) {
std::unordered_map<int, std::unordered_map<int, int>> n_e; std::unordered_map<int, std::unordered_map<int, int>> n_e;
for (dgl_id_t u = 0; u < g->NumVertices(); ++u) { for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {
for (const dgl_id_t v : g->SuccVec(u)) { for (const dgl_id_t v : g->SuccVec(u)) {
...@@ -382,8 +395,80 @@ ImmutableGraph GraphOp::ToBidirectedImmutableGraph(const GraphInterface* g) { ...@@ -382,8 +395,80 @@ ImmutableGraph GraphOp::ToBidirectedImmutableGraph(const GraphInterface* g) {
IdArray srcs_array = aten::VecToIdArray(srcs); IdArray srcs_array = aten::VecToIdArray(srcs);
IdArray dsts_array = aten::VecToIdArray(dsts); IdArray dsts_array = aten::VecToIdArray(dsts);
COOPtr coo(new COO(g->NumVertices(), srcs_array, dsts_array, g->IsMultigraph())); return ImmutableGraph::CreateFromCOO(
return ImmutableGraph(coo); g->NumVertices(), srcs_array, dsts_array, g->IsMultigraph());
} }
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
List<GraphRef> graphs = args[0];
std::vector<GraphPtr> ptrs(graphs.size());
for (size_t i = 0; i < graphs.size(); ++i) {
ptrs[i] = graphs[i].sptr();
}
*rv = GraphOp::DisjointUnion(ptrs);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
int64_t num = args[1];
const auto& ret = GraphOp::DisjointPartitionByNum(g.sptr(), num);
List<GraphRef> ret_list;
for (GraphPtr gp : ret) {
ret_list.push_back(GraphRef(gp));
}
*rv = ret_list;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray sizes = args[1];
const auto& ret = GraphOp::DisjointPartitionBySizes(g.sptr(), sizes);
List<GraphRef> ret_list;
for (GraphPtr gp : ret) {
ret_list.push_back(GraphRef(gp));
}
*rv = ret_list;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLineGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
bool backtracking = args[1];
*rv = GraphOp::LineGraph(g.sptr(), backtracking);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLToImmutable")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = ImmutableGraph::ToImmutable(g.sptr());
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = GraphOp::ToSimpleGraph(g.sptr());
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedMutableGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = GraphOp::ToBidirectedMutableGraph(g.sptr());
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedImmutableGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = GraphOp::ToBidirectedImmutableGraph(g.sptr());
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLMapSubgraphNID")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const IdArray parent_vids = args[0];
const IdArray query = args[1];
*rv = GraphOp::MapParentIdToSubgraphId(parent_vids, query);
});
} // namespace dgl } // namespace dgl
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* \brief DGL immutable graph index implementation * \brief DGL immutable graph index implementation
*/ */
#include <dgl/packed_func_ext.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <string.h> #include <string.h>
#include <bitset> #include <bitset>
...@@ -12,8 +13,14 @@ ...@@ -12,8 +13,14 @@
#include "../c_api_common.h" #include "../c_api_common.h"
using namespace dgl::runtime;
namespace dgl { namespace dgl {
namespace { namespace {
inline std::string GetSharedMemName(const std::string &name, const std::string &edge_dir) {
return name + "_" + edge_dir;
}
std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory( std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) { const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) {
#ifndef _WIN32 #ifndef _WIN32
...@@ -124,22 +131,22 @@ bool CSR::IsMultigraph() const { ...@@ -124,22 +131,22 @@ bool CSR::IsMultigraph() const {
}); });
} }
CSR::EdgeArray CSR::OutEdges(dgl_id_t vid) const { EdgeArray CSR::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid); IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
IdArray ret_eid = aten::CSRGetRowData(adj_, vid); IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx); IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
return CSR::EdgeArray{ret_src, ret_dst, ret_eid}; return EdgeArray{ret_src, ret_dst, ret_eid};
} }
CSR::EdgeArray CSR::OutEdges(IdArray vids) const { EdgeArray CSR::OutEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
auto csrsubmat = aten::CSRSliceRows(adj_, vids); auto csrsubmat = aten::CSRSliceRows(adj_, vids);
auto coosubmat = aten::CSRToCOO(csrsubmat, false); auto coosubmat = aten::CSRToCOO(csrsubmat, false);
// Note that the row id in the csr submat is relabled, so // Note that the row id in the csr submat is relabled, so
// we need to recover it using an index select. // we need to recover it using an index select.
auto row = aten::IndexSelect(vids, coosubmat.row); auto row = aten::IndexSelect(vids, coosubmat.row);
return CSR::EdgeArray{row, coosubmat.col, coosubmat.data}; return EdgeArray{row, coosubmat.col, coosubmat.data};
} }
DegreeArray CSR::OutDegrees(IdArray vids) const { DegreeArray CSR::OutDegrees(IdArray vids) const {
...@@ -171,17 +178,17 @@ IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const { ...@@ -171,17 +178,17 @@ IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {
return aten::CSRGetData(adj_, src, dst); return aten::CSRGetData(adj_, src, dst);
} }
CSR::EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const { EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
const auto& arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids); const auto& arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);
return CSR::EdgeArray{arrs[0], arrs[1], arrs[2]}; return EdgeArray{arrs[0], arrs[1], arrs[2]};
} }
CSR::EdgeArray CSR::Edges(const std::string &order) const { EdgeArray CSR::Edges(const std::string &order) const {
CHECK(order.empty() || order == std::string("srcdst")) CHECK(order.empty() || order == std::string("srcdst"))
<< "CSR only support Edges of order \"srcdst\"," << "CSR only support Edges of order \"srcdst\","
<< " but got \"" << order << "\"."; << " but got \"" << order << "\".";
const auto& coo = aten::CSRToCOO(adj_, false); const auto& coo = aten::CSRToCOO(adj_, false);
return CSR::EdgeArray{coo.row, coo.col, coo.data}; return EdgeArray{coo.row, coo.col, coo.data};
} }
Subgraph CSR::VertexSubgraph(IdArray vids) const { Subgraph CSR::VertexSubgraph(IdArray vids) const {
...@@ -289,14 +296,14 @@ std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const { ...@@ -289,14 +296,14 @@ std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
return std::pair<dgl_id_t, dgl_id_t>(src, dst); return std::pair<dgl_id_t, dgl_id_t>(src, dst);
} }
COO::EdgeArray COO::FindEdges(IdArray eids) const { EdgeArray COO::FindEdges(IdArray eids) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array"; CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
return EdgeArray{aten::IndexSelect(adj_.row, eids), return EdgeArray{aten::IndexSelect(adj_.row, eids),
aten::IndexSelect(adj_.col, eids), aten::IndexSelect(adj_.col, eids),
eids}; eids};
} }
COO::EdgeArray COO::Edges(const std::string &order) const { EdgeArray COO::Edges(const std::string &order) const {
CHECK(order.empty() || order == std::string("eid")) CHECK(order.empty() || order == std::string("eid"))
<< "COO only support Edges of order \"eid\", but got \"" << "COO only support Edges of order \"eid\", but got \""
<< order << "\"."; << order << "\".";
...@@ -411,7 +418,7 @@ COOPtr ImmutableGraph::GetCOO() const { ...@@ -411,7 +418,7 @@ COOPtr ImmutableGraph::GetCOO() const {
return coo_; return coo_;
} }
ImmutableGraph::EdgeArray ImmutableGraph::Edges(const std::string &order) const { EdgeArray ImmutableGraph::Edges(const std::string &order) const {
if (order.empty()) { if (order.empty()) {
// arbitrary order // arbitrary order
if (in_csr_) { if (in_csr_) {
...@@ -467,53 +474,177 @@ std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &f ...@@ -467,53 +474,177 @@ std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &f
} }
} }
ImmutableGraph ImmutableGraph::ToImmutable(const GraphInterface* graph) { ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
const ImmutableGraph* ig = dynamic_cast<const ImmutableGraph*>(graph); IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir) {
CSRPtr csr(new CSR(indptr, indices, edge_ids));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir,
const std::string &shared_mem_name) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, GetSharedMemName(shared_mem_name, edge_dir)));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir,
const std::string &shared_mem_name) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph,
GetSharedMemName(shared_mem_name, edge_dir)));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
const std::string &shared_mem_name, size_t num_vertices,
size_t num_edges, bool multigraph,
const std::string &edge_dir) {
CSRPtr csr(new CSR(GetSharedMemName(shared_mem_name, edge_dir), num_vertices, num_edges,
multigraph));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst) {
COOPtr coo(new COO(num_vertices, src, dst));
return std::make_shared<ImmutableGraph>(coo);
}
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst, bool multigraph) {
COOPtr coo(new COO(num_vertices, src, dst, multigraph));
return std::make_shared<ImmutableGraph>(coo);
}
ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {
ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(graph);
if (ig) { if (ig) {
return *ig; return ig;
} else { } else {
const auto& adj = graph->GetAdj(true, "csr"); const auto& adj = graph->GetAdj(true, "csr");
CSRPtr csr(new CSR(adj[0], adj[1], adj[2])); CSRPtr csr(new CSR(adj[0], adj[1], adj[2]));
return ImmutableGraph(nullptr, csr); return ImmutableGraph::CreateFromCSR(adj[0], adj[1], adj[2], "out");
} }
} }
ImmutableGraph ImmutableGraph::CopyTo(const DLContext& ctx) const { ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DLContext& ctx) {
if (ctx == Context()) { if (ctx == g->Context()) {
return *this; return g;
} }
// TODO(minjie): since we don't have GPU implementation of COO<->CSR, // TODO(minjie): since we don't have GPU implementation of COO<->CSR,
// we make sure that this graph (on CPU) has materialized CSR, // we make sure that this graph (on CPU) has materialized CSR,
// and then copy them to other context (usually GPU). This should // and then copy them to other context (usually GPU). This should
// be fixed later. // be fixed later.
CSRPtr new_incsr = CSRPtr(new CSR(GetInCSR()->CopyTo(ctx))); CSRPtr new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyTo(ctx)));
CSRPtr new_outcsr = CSRPtr(new CSR(GetOutCSR()->CopyTo(ctx))); CSRPtr new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyTo(ctx)));
return ImmutableGraph(new_incsr, new_outcsr); return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr));
} }
ImmutableGraph ImmutableGraph::CopyToSharedMem(const std::string &edge_dir, ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(ImmutableGraphPtr g,
const std::string &name) const { const std::string &edge_dir, const std::string &name) {
CSRPtr new_incsr, new_outcsr; CSRPtr new_incsr, new_outcsr;
std::string shared_mem_name = GetSharedMemName(name, edge_dir); std::string shared_mem_name = GetSharedMemName(name, edge_dir);
if (edge_dir == std::string("in")) if (edge_dir == std::string("in"))
new_incsr = CSRPtr(new CSR(GetInCSR()->CopyToSharedMem(shared_mem_name))); new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyToSharedMem(shared_mem_name)));
else if (edge_dir == std::string("out")) else if (edge_dir == std::string("out"))
new_outcsr = CSRPtr(new CSR(GetOutCSR()->CopyToSharedMem(shared_mem_name))); new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyToSharedMem(shared_mem_name)));
return ImmutableGraph(new_incsr, new_outcsr, name); return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr, name));
} }
ImmutableGraph ImmutableGraph::AsNumBits(uint8_t bits) const { ImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) {
if (NumBits() == bits) { if (g->NumBits() == bits) {
return *this; return g;
} else { } else {
// TODO(minjie): since we don't have int32 operations, // TODO(minjie): since we don't have int32 operations,
// we make sure that this graph (on CPU) has materialized CSR, // we make sure that this graph (on CPU) has materialized CSR,
// and then copy them to other context (usually GPU). This should // and then copy them to other context (usually GPU). This should
// be fixed later. // be fixed later.
CSRPtr new_incsr = CSRPtr(new CSR(GetInCSR()->AsNumBits(bits))); CSRPtr new_incsr = CSRPtr(new CSR(g->GetInCSR()->AsNumBits(bits)));
CSRPtr new_outcsr = CSRPtr(new CSR(GetOutCSR()->AsNumBits(bits))); CSRPtr new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->AsNumBits(bits)));
return ImmutableGraph(new_incsr, new_outcsr); return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr));
} }
} }
ImmutableGraphPtr ImmutableGraph::Reverse() const {
if (coo_) {
return ImmutableGraphPtr(new ImmutableGraph(
out_csr_, in_csr_, coo_->Transpose()));
} else {
return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_));
}
}
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const int device_type = args[1];
const int device_id = args[2];
DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::CopyTo(ig, ctx);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
std::string edge_dir = args[1];
std::string name = args[2];
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::CopyToSharedMem(ig, edge_dir, name);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
int bits = args[1];
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::AsNumBits(ig, bits);
});
} // namespace dgl } // namespace dgl
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
* \brief DGL networking related APIs * \brief DGL networking related APIs
*/ */
#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include "./network.h" #include "./network.h"
#include "./network/communicator.h" #include "./network/communicator.h"
#include "./network/socket_communicator.h" #include "./network/socket_communicator.h"
...@@ -11,11 +13,7 @@ ...@@ -11,11 +13,7 @@
#include "../c_api_common.h" #include "../c_api_common.h"
using dgl::runtime::DGLArgs; using namespace dgl::runtime;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
namespace dgl { namespace dgl {
namespace network { namespace network {
...@@ -84,12 +82,14 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph") ...@@ -84,12 +82,14 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
int recv_id = args[1]; int recv_id = args[1];
GraphHandle ghandle = args[2]; // TODO(minjie): could simply use NodeFlow nf = args[2];
GraphRef g = args[2];
const IdArray node_mapping = args[3]; const IdArray node_mapping = args[3];
const IdArray edge_mapping = args[4]; const IdArray edge_mapping = args[4];
const IdArray layer_offsets = args[5]; const IdArray layer_offsets = args[5];
const IdArray flow_offsets = args[6]; const IdArray flow_offsets = args[6];
ImmutableGraph *ptr = static_cast<ImmutableGraph*>(ghandle); auto ptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(ptr) << "only immutable graph is allowed in send/recv";
network::Sender* sender = static_cast<network::Sender*>(chandle); network::Sender* sender = static_cast<network::Sender*>(chandle);
auto csr = ptr->GetInCSR(); auto csr = ptr->GetInCSR();
// Write control message // Write control message
...@@ -159,7 +159,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph") ...@@ -159,7 +159,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
RecvData(receiver, buffer, kMaxBufferSize); RecvData(receiver, buffer, kMaxBufferSize);
int control = *buffer; int control = *buffer;
if (control == CONTROL_NODEFLOW) { if (control == CONTROL_NODEFLOW) {
NodeFlow* nf = new NodeFlow(); NodeFlow nf = NodeFlow::Create();
CSRPtr csr; CSRPtr csr;
// Deserialize nodeflow from recv_data_buffer // Deserialize nodeflow from recv_data_buffer
network::DeserializeSampledSubgraph(buffer+sizeof(CONTROL_NODEFLOW), network::DeserializeSampledSubgraph(buffer+sizeof(CONTROL_NODEFLOW),
...@@ -169,9 +169,9 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph") ...@@ -169,9 +169,9 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
&(nf->layer_offsets), &(nf->layer_offsets),
&(nf->flow_offsets)); &(nf->flow_offsets));
nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr)); nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr));
std::vector<NodeFlow*> subgs(1); List<NodeFlow> subgs;
subgs[0] = nf; subgs.push_back(nf);
*rv = WrapVectorReturn(subgs); *rv = subgs;
} else if (control == CONTROL_END_SIGNAL) { } else if (control == CONTROL_END_SIGNAL) {
*rv = CONTROL_END_SIGNAL; *rv = CONTROL_END_SIGNAL;
} else { } else {
......
...@@ -5,9 +5,10 @@ ...@@ -5,9 +5,10 @@
*/ */
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/nodeflow.h> #include <dgl/nodeflow.h>
#include <string.h> #include <string>
#include "../c_api_common.h" #include "../c_api_common.h"
...@@ -78,15 +79,14 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st ...@@ -78,15 +79,14 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetBlockAdj") DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetBlockAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphRef g = args[0];
std::string format = args[1]; std::string format = args[1];
int64_t layer0_size = args[2]; int64_t layer0_size = args[2];
int64_t start = args[3]; int64_t start = args[3];
int64_t end = args[4]; int64_t end = args[4];
const bool remap = args[5]; const bool remap = args[5];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle); auto ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
const ImmutableGraph* gptr = dynamic_cast<const ImmutableGraph*>(ptr); auto res = GetNodeFlowSlice(*ig, format, layer0_size, start, end, remap);
auto res = GetNodeFlowSlice(*gptr, format, layer0_size, start, end, remap);
*rv = ConvertNDArrayVectorToPackedFunc(res); *rv = ConvertNDArrayVectorToPackedFunc(res);
}); });
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <dgl/sampler.h> #include <dgl/sampler.h>
#include <dmlc/omp.h> #include <dmlc/omp.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <algorithm> #include <algorithm>
#include <cstdlib> #include <cstdlib>
#include <cmath> #include <cmath>
...@@ -14,11 +15,7 @@ ...@@ -14,11 +15,7 @@
#include <functional> #include <functional>
#include "../c_api_common.h" #include "../c_api_common.h"
using dgl::runtime::DGLArgs; using namespace dgl::runtime;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
namespace dgl { namespace dgl {
...@@ -218,43 +215,40 @@ RandomWalkTraces BipartiteSingleSidedRandomWalkWithRestart( ...@@ -218,43 +215,40 @@ RandomWalkTraces BipartiteSingleSidedRandomWalkWithRestart(
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalk") DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphRef g = args[0];
const IdArray seeds = args[1]; const IdArray seeds = args[1];
const int num_traces = args[2]; const int num_traces = args[2];
const int num_hops = args[3]; const int num_hops = args[3];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
*rv = RandomWalk(ptr, seeds, num_traces, num_hops); *rv = RandomWalk(g.sptr().get(), seeds, num_traces, num_hops);
}); });
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalkWithRestart") DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphRef g = args[0];
const IdArray seeds = args[1]; const IdArray seeds = args[1];
const double restart_prob = args[2]; const double restart_prob = args[2];
const uint64_t visit_threshold_per_seed = args[3]; const uint64_t visit_threshold_per_seed = args[3];
const uint64_t max_visit_counts = args[4]; const uint64_t max_visit_counts = args[4];
const uint64_t max_frequent_visited_nodes = args[5]; const uint64_t max_frequent_visited_nodes = args[5];
const GraphInterface *gptr = static_cast<const GraphInterface *>(ghandle);
*rv = ConvertRandomWalkTracesToPackedFunc( *rv = ConvertRandomWalkTracesToPackedFunc(
RandomWalkWithRestart(gptr, seeds, restart_prob, visit_threshold_per_seed, RandomWalkWithRestart(g.sptr().get(), seeds, restart_prob, visit_threshold_per_seed,
max_visit_counts, max_frequent_visited_nodes)); max_visit_counts, max_frequent_visited_nodes));
}); });
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLBipartiteSingleSidedRandomWalkWithRestart") DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLBipartiteSingleSidedRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphRef g = args[0];
const IdArray seeds = args[1]; const IdArray seeds = args[1];
const double restart_prob = args[2]; const double restart_prob = args[2];
const uint64_t visit_threshold_per_seed = args[3]; const uint64_t visit_threshold_per_seed = args[3];
const uint64_t max_visit_counts = args[4]; const uint64_t max_visit_counts = args[4];
const uint64_t max_frequent_visited_nodes = args[5]; const uint64_t max_frequent_visited_nodes = args[5];
const GraphInterface *gptr = static_cast<const GraphInterface *>(ghandle);
*rv = ConvertRandomWalkTracesToPackedFunc( *rv = ConvertRandomWalkTracesToPackedFunc(
BipartiteSingleSidedRandomWalkWithRestart( BipartiteSingleSidedRandomWalkWithRestart(
gptr, seeds, restart_prob, visit_threshold_per_seed, g.sptr().get(), seeds, restart_prob, visit_threshold_per_seed,
max_visit_counts, max_frequent_visited_nodes)); max_visit_counts, max_frequent_visited_nodes));
}); });
......
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
* \file graph/sampler.cc * \file graph/sampler.cc
* \brief DGL sampler implementation * \brief DGL sampler implementation
*/ */
#include <dgl/sampler.h> #include <dgl/sampler.h>
#include <dmlc/omp.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include <dmlc/omp.h> #include <dmlc/omp.h>
#include <algorithm> #include <algorithm>
#include <cstdlib> #include <cstdlib>
...@@ -14,11 +14,7 @@ ...@@ -14,11 +14,7 @@
#include <numeric> #include <numeric>
#include "../c_api_common.h" #include "../c_api_common.h"
using dgl::runtime::DGLArgs; using namespace dgl::runtime;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
namespace dgl { namespace dgl {
...@@ -246,17 +242,17 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list, ...@@ -246,17 +242,17 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
std::vector<neighbor_info> *neigh_pos, std::vector<neighbor_info> *neigh_pos,
const std::string &edge_type, const std::string &edge_type,
int64_t num_edges, int num_hops, bool is_multigraph) { int64_t num_edges, int num_hops, bool is_multigraph) {
NodeFlow nf; NodeFlow nf = NodeFlow::Create();
uint64_t num_vertices = sub_vers->size(); uint64_t num_vertices = sub_vers->size();
nf.node_mapping = aten::NewIdArray(num_vertices); nf->node_mapping = aten::NewIdArray(num_vertices);
nf.edge_mapping = aten::NewIdArray(num_edges); nf->edge_mapping = aten::NewIdArray(num_edges);
nf.layer_offsets = aten::NewIdArray(num_hops + 1); nf->layer_offsets = aten::NewIdArray(num_hops + 1);
nf.flow_offsets = aten::NewIdArray(num_hops); nf->flow_offsets = aten::NewIdArray(num_hops);
dgl_id_t *node_map_data = static_cast<dgl_id_t *>(nf.node_mapping->data); dgl_id_t *node_map_data = static_cast<dgl_id_t *>(nf->node_mapping->data);
dgl_id_t *layer_off_data = static_cast<dgl_id_t *>(nf.layer_offsets->data); dgl_id_t *layer_off_data = static_cast<dgl_id_t *>(nf->layer_offsets->data);
dgl_id_t *flow_off_data = static_cast<dgl_id_t *>(nf.flow_offsets->data); dgl_id_t *flow_off_data = static_cast<dgl_id_t *>(nf->flow_offsets->data);
dgl_id_t *edge_map_data = static_cast<dgl_id_t *>(nf.edge_mapping->data); dgl_id_t *edge_map_data = static_cast<dgl_id_t *>(nf->edge_mapping->data);
// Construct sub_csr_graph // Construct sub_csr_graph
// TODO(minjie): is nodeflow a multigraph? // TODO(minjie): is nodeflow a multigraph?
...@@ -364,9 +360,9 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list, ...@@ -364,9 +360,9 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
std::iota(eid_out, eid_out + num_edges, 0); std::iota(eid_out, eid_out + num_edges, 0);
if (edge_type == std::string("in")) { if (edge_type == std::string("in")) {
nf.graph = GraphPtr(new ImmutableGraph(subg_csr, nullptr)); nf->graph = GraphPtr(new ImmutableGraph(subg_csr, nullptr));
} else { } else {
nf.graph = GraphPtr(new ImmutableGraph(nullptr, subg_csr)); nf->graph = GraphPtr(new ImmutableGraph(nullptr, subg_csr));
} }
return nf; return nf;
...@@ -491,47 +487,34 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph, ...@@ -491,47 +487,34 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetGraph") DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0]; NodeFlow nflow = args[0];
const NodeFlow* nflow = static_cast<NodeFlow*>(ptr); *rv = nflow->graph;
GraphInterface* gptr = nflow->graph->Reset();
*rv = gptr;
}); });
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetNodeMapping") DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetNodeMapping")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0]; NodeFlow nflow = args[0];
const NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
*rv = nflow->node_mapping; *rv = nflow->node_mapping;
}); });
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetEdgeMapping") DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetEdgeMapping")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0]; NodeFlow nflow = args[0];
const NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
*rv = nflow->edge_mapping; *rv = nflow->edge_mapping;
}); });
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetLayerOffsets") DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetLayerOffsets")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0]; NodeFlow nflow = args[0];
const NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
*rv = nflow->layer_offsets; *rv = nflow->layer_offsets;
}); });
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetBlockOffsets") DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetBlockOffsets")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0]; NodeFlow nflow = args[0];
const NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
*rv = nflow->flow_offsets; *rv = nflow->flow_offsets;
}); });
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowFree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
delete nflow;
});
NodeFlow SamplerOp::NeighborUniformSample(const ImmutableGraph *graph, NodeFlow SamplerOp::NeighborUniformSample(const ImmutableGraph *graph,
const std::vector<dgl_id_t>& seeds, const std::vector<dgl_id_t>& seeds,
const std::string &edge_type, const std::string &edge_type,
...@@ -702,21 +685,21 @@ NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph, ...@@ -702,21 +685,21 @@ NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph,
CHECK_EQ(sub_indptr.back(), sub_indices.size()); CHECK_EQ(sub_indptr.back(), sub_indices.size());
CHECK_EQ(sub_indices.size(), sub_edge_ids.size()); CHECK_EQ(sub_indices.size(), sub_edge_ids.size());
NodeFlow nf; NodeFlow nf = NodeFlow::Create();
auto sub_csr = CSRPtr(new CSR(aten::VecToIdArray(sub_indptr), auto sub_csr = CSRPtr(new CSR(aten::VecToIdArray(sub_indptr),
aten::VecToIdArray(sub_indices), aten::VecToIdArray(sub_indices),
aten::VecToIdArray(sub_edge_ids))); aten::VecToIdArray(sub_edge_ids)));
if (neighbor_type == std::string("in")) { if (neighbor_type == std::string("in")) {
nf.graph = GraphPtr(new ImmutableGraph(sub_csr, nullptr)); nf->graph = GraphPtr(new ImmutableGraph(sub_csr, nullptr));
} else { } else {
nf.graph = GraphPtr(new ImmutableGraph(nullptr, sub_csr)); nf->graph = GraphPtr(new ImmutableGraph(nullptr, sub_csr));
} }
nf.node_mapping = aten::VecToIdArray(node_mapping); nf->node_mapping = aten::VecToIdArray(node_mapping);
nf.edge_mapping = aten::VecToIdArray(edge_mapping); nf->edge_mapping = aten::VecToIdArray(edge_mapping);
nf.layer_offsets = aten::VecToIdArray(layer_offsets); nf->layer_offsets = aten::VecToIdArray(layer_offsets);
nf.flow_offsets = aten::VecToIdArray(flow_offsets); nf->flow_offsets = aten::VecToIdArray(flow_offsets);
return nf; return nf;
} }
...@@ -736,7 +719,7 @@ void BuildCsr(const ImmutableGraph &g, const std::string neigh_type) { ...@@ -736,7 +719,7 @@ void BuildCsr(const ImmutableGraph &g, const std::string neigh_type) {
DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling") DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
// arguments // arguments
const GraphHandle ghdl = args[0]; GraphRef g = args[0];
const IdArray seed_nodes = args[1]; const IdArray seed_nodes = args[1];
const int64_t batch_start_id = args[2]; const int64_t batch_start_id = args[2];
const int64_t batch_size = args[3]; const int64_t batch_size = args[3];
...@@ -746,8 +729,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling") ...@@ -746,8 +729,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
const std::string neigh_type = args[7]; const std::string neigh_type = args[7];
const bool add_self_loop = args[8]; const bool add_self_loop = args[8];
// process args // process args
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghdl); auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
const ImmutableGraph *gptr = dynamic_cast<const ImmutableGraph*>(ptr);
CHECK(gptr) << "sampling isn't implemented in mutable graph"; CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(IsValidIdArray(seed_nodes)); CHECK(IsValidIdArray(seed_nodes));
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data); const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
...@@ -757,7 +739,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling") ...@@ -757,7 +739,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
// We need to make sure we have the right CSR before we enter parallel sampling. // We need to make sure we have the right CSR before we enter parallel sampling.
BuildCsr(*gptr, neigh_type); BuildCsr(*gptr, neigh_type);
// generate node flows // generate node flows
std::vector<NodeFlow*> nflows(num_workers); std::vector<NodeFlow> nflows(num_workers);
#pragma omp parallel for #pragma omp parallel for
for (int i = 0; i < num_workers; i++) { for (int i = 0; i < num_workers; i++) {
// create per-worker seed nodes. // create per-worker seed nodes.
...@@ -767,17 +749,16 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling") ...@@ -767,17 +749,16 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
std::vector<dgl_id_t> worker_seeds(end - start); std::vector<dgl_id_t> worker_seeds(end - start);
std::copy(seed_nodes_data + start, seed_nodes_data + end, std::copy(seed_nodes_data + start, seed_nodes_data + end,
worker_seeds.begin()); worker_seeds.begin());
nflows[i] = new NodeFlow(); nflows[i] = SamplerOp::NeighborUniformSample(
*nflows[i] = SamplerOp::NeighborUniformSample( gptr.get(), worker_seeds, neigh_type, num_hops, expand_factor, add_self_loop);
gptr, worker_seeds, neigh_type, num_hops, expand_factor, add_self_loop);
} }
*rv = WrapVectorReturn(nflows); *rv = List<NodeFlow>(nflows);
}); });
DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling") DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
// arguments // arguments
const GraphHandle ghdl = args[0]; GraphRef g = args[0];
const IdArray seed_nodes = args[1]; const IdArray seed_nodes = args[1];
const int64_t batch_start_id = args[2]; const int64_t batch_start_id = args[2];
const int64_t batch_size = args[3]; const int64_t batch_size = args[3];
...@@ -785,8 +766,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling") ...@@ -785,8 +766,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
const IdArray layer_sizes = args[5]; const IdArray layer_sizes = args[5];
const std::string neigh_type = args[6]; const std::string neigh_type = args[6];
// process args // process args
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghdl); auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
const ImmutableGraph *gptr = dynamic_cast<const ImmutableGraph*>(ptr);
CHECK(gptr) << "sampling isn't implemented in mutable graph"; CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(IsValidIdArray(seed_nodes)); CHECK(IsValidIdArray(seed_nodes));
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data); const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
...@@ -796,7 +776,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling") ...@@ -796,7 +776,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
// We need to make sure we have the right CSR before we enter parallel sampling. // We need to make sure we have the right CSR before we enter parallel sampling.
BuildCsr(*gptr, neigh_type); BuildCsr(*gptr, neigh_type);
// generate node flows // generate node flows
std::vector<NodeFlow*> nflows(num_workers); std::vector<NodeFlow> nflows(num_workers);
#pragma omp parallel for #pragma omp parallel for
for (int i = 0; i < num_workers; i++) { for (int i = 0; i < num_workers; i++) {
// create per-worker seed nodes. // create per-worker seed nodes.
...@@ -806,11 +786,10 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling") ...@@ -806,11 +786,10 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
std::vector<dgl_id_t> worker_seeds(end - start); std::vector<dgl_id_t> worker_seeds(end - start);
std::copy(seed_nodes_data + start, seed_nodes_data + end, std::copy(seed_nodes_data + start, seed_nodes_data + end,
worker_seeds.begin()); worker_seeds.begin());
nflows[i] = new NodeFlow(); nflows[i] = SamplerOp::LayerUniformSample(
*nflows[i] = SamplerOp::LayerUniformSample( gptr.get(), worker_seeds, neigh_type, layer_sizes);
gptr, worker_seeds, neigh_type, layer_sizes);
} }
*rv = WrapVectorReturn(nflows); *rv = List<NodeFlow>(nflows);
}); });
} // namespace dgl } // namespace dgl
...@@ -3,16 +3,13 @@ ...@@ -3,16 +3,13 @@
* \file graph/traversal.cc * \file graph/traversal.cc
* \brief Graph traversal implementation * \brief Graph traversal implementation
*/ */
#include <dgl/packed_func_ext.h>
#include <algorithm> #include <algorithm>
#include <queue> #include <queue>
#include "./traversal.h" #include "./traversal.h"
#include "../c_api_common.h" #include "../c_api_common.h"
using dgl::runtime::DGLArgs; using namespace dgl::runtime;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
namespace dgl { namespace dgl {
namespace traverse { namespace traverse {
...@@ -115,7 +112,7 @@ struct Frontiers { ...@@ -115,7 +112,7 @@ struct Frontiers {
std::vector<int64_t> sections; std::vector<int64_t> sections;
}; };
Frontiers BFSNodesFrontiers(const Graph& graph, IdArray source, bool reversed) { Frontiers BFSNodesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) {
Frontiers front; Frontiers front;
VectorQueueWrapper<dgl_id_t> queue(&front.ids); VectorQueueWrapper<dgl_id_t> queue(&front.ids);
auto visit = [&] (const dgl_id_t v) { }; auto visit = [&] (const dgl_id_t v) { };
...@@ -131,17 +128,16 @@ Frontiers BFSNodesFrontiers(const Graph& graph, IdArray source, bool reversed) { ...@@ -131,17 +128,16 @@ Frontiers BFSNodesFrontiers(const Graph& graph, IdArray source, bool reversed) {
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphRef g = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = args[1]; const IdArray src = args[1];
bool reversed = args[2]; bool reversed = args[2];
const auto& front = BFSNodesFrontiers(*gptr, src, reversed); const auto& front = BFSNodesFrontiers(*(g.sptr()), src, reversed);
IdArray node_ids = CopyVectorToNDArray(front.ids); IdArray node_ids = CopyVectorToNDArray(front.ids);
IdArray sections = CopyVectorToNDArray(front.sections); IdArray sections = CopyVectorToNDArray(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections}); *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
}); });
Frontiers BFSEdgesFrontiers(const Graph& graph, IdArray source, bool reversed) { Frontiers BFSEdgesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) {
Frontiers front; Frontiers front;
// NOTE: std::queue has no top() method. // NOTE: std::queue has no top() method.
std::vector<dgl_id_t> nodes; std::vector<dgl_id_t> nodes;
...@@ -162,17 +158,16 @@ Frontiers BFSEdgesFrontiers(const Graph& graph, IdArray source, bool reversed) { ...@@ -162,17 +158,16 @@ Frontiers BFSEdgesFrontiers(const Graph& graph, IdArray source, bool reversed) {
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphRef g = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = args[1]; const IdArray src = args[1];
bool reversed = args[2]; bool reversed = args[2];
const auto& front = BFSEdgesFrontiers(*gptr, src, reversed); const auto& front = BFSEdgesFrontiers(*(g.sptr()), src, reversed);
IdArray edge_ids = CopyVectorToNDArray(front.ids); IdArray edge_ids = CopyVectorToNDArray(front.ids);
IdArray sections = CopyVectorToNDArray(front.sections); IdArray sections = CopyVectorToNDArray(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections}); *rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections});
}); });
Frontiers TopologicalNodesFrontiers(const Graph& graph, bool reversed) { Frontiers TopologicalNodesFrontiers(const GraphInterface& graph, bool reversed) {
Frontiers front; Frontiers front;
VectorQueueWrapper<dgl_id_t> queue(&front.ids); VectorQueueWrapper<dgl_id_t> queue(&front.ids);
auto visit = [&] (const dgl_id_t v) { }; auto visit = [&] (const dgl_id_t v) { };
...@@ -188,10 +183,9 @@ Frontiers TopologicalNodesFrontiers(const Graph& graph, bool reversed) { ...@@ -188,10 +183,9 @@ Frontiers TopologicalNodesFrontiers(const Graph& graph, bool reversed) {
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphRef g = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
bool reversed = args[1]; bool reversed = args[1];
const auto& front = TopologicalNodesFrontiers(*gptr, reversed); const auto& front = TopologicalNodesFrontiers(*g.sptr(), reversed);
IdArray node_ids = CopyVectorToNDArray(front.ids); IdArray node_ids = CopyVectorToNDArray(front.ids);
IdArray sections = CopyVectorToNDArray(front.sections); IdArray sections = CopyVectorToNDArray(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections}); *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
...@@ -200,8 +194,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes") ...@@ -200,8 +194,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphRef g = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray source = args[1]; const IdArray source = args[1];
const bool reversed = args[2]; const bool reversed = args[2];
CHECK(IsValidIdArray(source)) << "Invalid source node id array."; CHECK(IsValidIdArray(source)) << "Invalid source node id array.";
...@@ -210,7 +203,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges") ...@@ -210,7 +203,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
std::vector<std::vector<dgl_id_t>> edges(len); std::vector<std::vector<dgl_id_t>> edges(len);
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
auto visit = [&] (dgl_id_t e, int tag) { edges[i].push_back(e); }; auto visit = [&] (dgl_id_t e, int tag) { edges[i].push_back(e); };
DFSLabeledEdges(*gptr, src_data[i], reversed, false, false, visit); DFSLabeledEdges(*g.sptr(), src_data[i], reversed, false, false, visit);
} }
IdArray ids = MergeMultipleTraversals(edges); IdArray ids = MergeMultipleTraversals(edges);
IdArray sections = ComputeMergedSections(edges); IdArray sections = ComputeMergedSections(edges);
...@@ -219,8 +212,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges") ...@@ -219,8 +212,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphRef g = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray source = args[1]; const IdArray source = args[1];
const bool reversed = args[2]; const bool reversed = args[2];
const bool has_reverse_edge = args[3]; const bool has_reverse_edge = args[3];
...@@ -243,7 +235,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges") ...@@ -243,7 +235,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
tags[i].push_back(tag); tags[i].push_back(tag);
} }
}; };
DFSLabeledEdges(*gptr, src_data[i], reversed, DFSLabeledEdges(*g.sptr(), src_data[i], reversed,
has_reverse_edge, has_nontree_edge, visit); has_reverse_edge, has_nontree_edge, visit);
} }
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#ifndef DGL_GRAPH_TRAVERSAL_H_ #ifndef DGL_GRAPH_TRAVERSAL_H_
#define DGL_GRAPH_TRAVERSAL_H_ #define DGL_GRAPH_TRAVERSAL_H_
#include <dgl/graph.h> #include <dgl/graph_interface.h>
#include <stack> #include <stack>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
...@@ -45,7 +45,7 @@ namespace traverse { ...@@ -45,7 +45,7 @@ namespace traverse {
* \param make_frontier The function to indicate that a new froniter can be made; * \param make_frontier The function to indicate that a new froniter can be made;
*/ */
template<typename Queue, typename VisitFn, typename FrontierFn> template<typename Queue, typename VisitFn, typename FrontierFn>
void BFSNodes(const Graph& graph, void BFSNodes(const GraphInterface& graph,
IdArray source, IdArray source,
bool reversed, bool reversed,
Queue* queue, Queue* queue,
...@@ -63,7 +63,7 @@ void BFSNodes(const Graph& graph, ...@@ -63,7 +63,7 @@ void BFSNodes(const Graph& graph,
} }
make_frontier(); make_frontier();
const auto neighbor_iter = reversed? &Graph::PredVec : &Graph::SuccVec; const auto neighbor_iter = reversed? &GraphInterface::PredVec : &GraphInterface::SuccVec;
while (!queue->empty()) { while (!queue->empty()) {
const size_t size = queue->size(); const size_t size = queue->size();
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
...@@ -109,7 +109,7 @@ void BFSNodes(const Graph& graph, ...@@ -109,7 +109,7 @@ void BFSNodes(const Graph& graph,
* \param make_frontier The function to indicate that a new frontier can be made; * \param make_frontier The function to indicate that a new frontier can be made;
*/ */
template<typename Queue, typename VisitFn, typename FrontierFn> template<typename Queue, typename VisitFn, typename FrontierFn>
void BFSEdges(const Graph& graph, void BFSEdges(const GraphInterface& graph,
IdArray source, IdArray source,
bool reversed, bool reversed,
Queue* queue, Queue* queue,
...@@ -126,7 +126,7 @@ void BFSEdges(const Graph& graph, ...@@ -126,7 +126,7 @@ void BFSEdges(const Graph& graph,
} }
make_frontier(); make_frontier();
const auto neighbor_iter = reversed? &Graph::InEdgeVec : &Graph::OutEdgeVec; const auto neighbor_iter = reversed? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;
while (!queue->empty()) { while (!queue->empty()) {
const size_t size = queue->size(); const size_t size = queue->size();
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
...@@ -171,13 +171,13 @@ void BFSEdges(const Graph& graph, ...@@ -171,13 +171,13 @@ void BFSEdges(const Graph& graph,
* \param make_frontier The function to indicate that a new froniter can be made; * \param make_frontier The function to indicate that a new froniter can be made;
*/ */
template<typename Queue, typename VisitFn, typename FrontierFn> template<typename Queue, typename VisitFn, typename FrontierFn>
void TopologicalNodes(const Graph& graph, void TopologicalNodes(const GraphInterface& graph,
bool reversed, bool reversed,
Queue* queue, Queue* queue,
VisitFn visit, VisitFn visit,
FrontierFn make_frontier) { FrontierFn make_frontier) {
const auto get_degree = reversed? &Graph::OutDegree : &Graph::InDegree; const auto get_degree = reversed? &GraphInterface::OutDegree : &GraphInterface::InDegree;
const auto neighbor_iter = reversed? &Graph::PredVec : &Graph::SuccVec; const auto neighbor_iter = reversed? &GraphInterface::PredVec : &GraphInterface::SuccVec;
uint64_t num_visited_nodes = 0; uint64_t num_visited_nodes = 0;
std::vector<uint64_t> degrees(graph.NumVertices(), 0); std::vector<uint64_t> degrees(graph.NumVertices(), 0);
for (dgl_id_t vid = 0; vid < graph.NumVertices(); ++vid) { for (dgl_id_t vid = 0; vid < graph.NumVertices(); ++vid) {
...@@ -237,14 +237,14 @@ enum DFSEdgeTag { ...@@ -237,14 +237,14 @@ enum DFSEdgeTag {
* tag will be given as the arguments. * tag will be given as the arguments.
*/ */
template<typename VisitFn> template<typename VisitFn>
void DFSLabeledEdges(const Graph& graph, void DFSLabeledEdges(const GraphInterface& graph,
dgl_id_t source, dgl_id_t source,
bool reversed, bool reversed,
bool has_reverse_edge, bool has_reverse_edge,
bool has_nontree_edge, bool has_nontree_edge,
VisitFn visit) { VisitFn visit) {
const auto succ = reversed? &Graph::PredVec : &Graph::SuccVec; const auto succ = reversed? &GraphInterface::PredVec : &GraphInterface::SuccVec;
const auto out_edge = reversed? &Graph::InEdgeVec : &Graph::OutEdgeVec; const auto out_edge = reversed? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;
if ((graph.*succ)(source).size() == 0) { if ((graph.*succ)(source).size() == 0) {
// no out-going edges from the source node // no out-going edges from the source node
......
...@@ -3,17 +3,14 @@ ...@@ -3,17 +3,14 @@
* \file kernel/binary_reduce.cc * \file kernel/binary_reduce.cc
* \brief Binary reduce C APIs and definitions. * \brief Binary reduce C APIs and definitions.
*/ */
#include <dgl/packed_func_ext.h>
#include "./binary_reduce.h" #include "./binary_reduce.h"
#include "./common.h" #include "./common.h"
#include "./binary_reduce_impl_decl.h" #include "./binary_reduce_impl_decl.h"
#include "./utils.h" #include "./utils.h"
#include "../c_api_common.h" #include "../c_api_common.h"
using dgl::runtime::DGLArgs; using namespace dgl::runtime;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
namespace dgl { namespace dgl {
namespace kernel { namespace kernel {
...@@ -273,7 +270,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce") ...@@ -273,7 +270,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0]; std::string reducer = args[0];
std::string op = args[1]; std::string op = args[1];
GraphHandle ghdl = args[2]; GraphRef g = args[2];
int lhs = args[3]; int lhs = args[3];
int rhs = args[4]; int rhs = args[4];
NDArray lhs_data = args[5]; NDArray lhs_data = args[5];
...@@ -283,10 +280,9 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce") ...@@ -283,10 +280,9 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
NDArray rhs_mapping = args[9]; NDArray rhs_mapping = args[9];
NDArray out_mapping = args[10]; NDArray out_mapping = args[10];
GraphInterface* gptr = static_cast<GraphInterface*>(ghdl); auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
const ImmutableGraph* igptr = dynamic_cast<ImmutableGraph*>(gptr);
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph."; CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
BinaryOpReduce(reducer, op, igptr, BinaryOpReduce(reducer, op, igptr.get(),
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs), static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_data, rhs_data, out_data, lhs_data, rhs_data, out_data,
lhs_mapping, rhs_mapping, out_mapping); lhs_mapping, rhs_mapping, out_mapping);
...@@ -346,7 +342,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce") ...@@ -346,7 +342,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0]; std::string reducer = args[0];
std::string op = args[1]; std::string op = args[1];
GraphHandle ghdl = args[2]; GraphRef g = args[2];
int lhs = args[3]; int lhs = args[3];
int rhs = args[4]; int rhs = args[4];
NDArray lhs_mapping = args[5]; NDArray lhs_mapping = args[5];
...@@ -358,11 +354,10 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce") ...@@ -358,11 +354,10 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
NDArray grad_out_data = args[11]; NDArray grad_out_data = args[11];
NDArray grad_lhs_data = args[12]; NDArray grad_lhs_data = args[12];
GraphInterface* gptr = static_cast<GraphInterface*>(ghdl); auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
const ImmutableGraph* igptr = dynamic_cast<ImmutableGraph*>(gptr);
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph."; CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
BackwardLhsBinaryOpReduce( BackwardLhsBinaryOpReduce(
reducer, op, igptr, reducer, op, igptr.get(),
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs), static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_mapping, rhs_mapping, out_mapping, lhs_mapping, rhs_mapping, out_mapping,
lhs_data, rhs_data, out_data, grad_out_data, lhs_data, rhs_data, out_data, grad_out_data,
...@@ -422,7 +417,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce") ...@@ -422,7 +417,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0]; std::string reducer = args[0];
std::string op = args[1]; std::string op = args[1];
GraphHandle ghdl = args[2]; GraphRef g = args[2];
int lhs = args[3]; int lhs = args[3];
int rhs = args[4]; int rhs = args[4];
NDArray lhs_mapping = args[5]; NDArray lhs_mapping = args[5];
...@@ -434,11 +429,10 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce") ...@@ -434,11 +429,10 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
NDArray grad_out_data = args[11]; NDArray grad_out_data = args[11];
NDArray grad_rhs_data = args[12]; NDArray grad_rhs_data = args[12];
GraphInterface* gptr = static_cast<GraphInterface*>(ghdl); auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
const ImmutableGraph* igptr = dynamic_cast<ImmutableGraph*>(gptr);
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph."; CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
BackwardRhsBinaryOpReduce( BackwardRhsBinaryOpReduce(
reducer, op, igptr, reducer, op, igptr.get(),
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs), static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_mapping, rhs_mapping, out_mapping, lhs_mapping, rhs_mapping, out_mapping,
lhs_data, rhs_data, out_data, grad_out_data, lhs_data, rhs_data, out_data, grad_out_data,
...@@ -469,17 +463,16 @@ void CopyReduce( ...@@ -469,17 +463,16 @@ void CopyReduce(
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce") DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0]; std::string reducer = args[0];
GraphHandle ghdl = args[1]; GraphRef g = args[1];
int target = args[2]; int target = args[2];
NDArray in_data = args[3]; NDArray in_data = args[3];
NDArray out_data = args[4]; NDArray out_data = args[4];
NDArray in_mapping = args[5]; NDArray in_mapping = args[5];
NDArray out_mapping = args[6]; NDArray out_mapping = args[6];
GraphInterface* gptr = static_cast<GraphInterface*>(ghdl); auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
const ImmutableGraph* igptr = dynamic_cast<ImmutableGraph*>(gptr);
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph."; CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
CopyReduce(reducer, igptr, CopyReduce(reducer, igptr.get(),
static_cast<binary_op::Target>(target), static_cast<binary_op::Target>(target),
in_data, out_data, in_data, out_data,
in_mapping, out_mapping); in_mapping, out_mapping);
...@@ -518,7 +511,7 @@ void BackwardCopyReduce( ...@@ -518,7 +511,7 @@ void BackwardCopyReduce(
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce") DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0]; std::string reducer = args[0];
GraphHandle ghdl = args[1]; GraphRef g = args[1];
int target = args[2]; int target = args[2];
NDArray in_data = args[3]; NDArray in_data = args[3];
NDArray out_data = args[4]; NDArray out_data = args[4];
...@@ -527,11 +520,10 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce") ...@@ -527,11 +520,10 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce")
NDArray in_mapping = args[7]; NDArray in_mapping = args[7];
NDArray out_mapping = args[8]; NDArray out_mapping = args[8];
GraphInterface* gptr = static_cast<GraphInterface*>(ghdl); auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
const ImmutableGraph* igptr = dynamic_cast<ImmutableGraph*>(gptr);
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph."; CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
BackwardCopyReduce( BackwardCopyReduce(
reducer, igptr, static_cast<binary_op::Target>(target), reducer, igptr.get(), static_cast<binary_op::Target>(target),
in_mapping, out_mapping, in_mapping, out_mapping,
in_data, out_data, grad_out_data, in_data, out_data, grad_out_data,
grad_in_data); grad_in_data);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment