Unverified Commit 67dc1197 authored by Minjie Wang, committed by GitHub

[Refactor] Use object system for all CAPIs (#716)

* WIP: using object system for graph

* c++ side refactoring done; compiled

* remove stale apis

* fix bug in DGLGraphCreate; passed test_graph.py

* fix bug in python modify; passed utest for pytorch/cpu

* fix lint

* address comments
parent b0d9e7aa
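For orientation: the change repeated across every file below is the move from opaque void* handles (GraphHandle, CommunicatorHandle, and the CAPIVectorWrapper/WrapVectorReturn machinery) to the runtime object system. C API bodies now receive GraphRef/NodeFlow arguments backed by shared_ptr and assign object references straight to the return value. The side-by-side sketch below is illustrative only and not part of the commit; the API names are hypothetical, and the includes are the ones the hunks below add (dgl/packed_func_ext.h, dgl/runtime/container.h).

// Before: opaque handle in, manual static_cast, raw pointer or integer out.
DGL_REGISTER_GLOBAL("example._CAPI_OldStyle")          // hypothetical name
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle ghandle = args[0];
    const GraphInterface* gptr = static_cast<const GraphInterface*>(ghandle);
    *rv = static_cast<int64_t>(gptr->NumVertices());
  });

// After: GraphRef in, shared_ptr obtained via sptr(), object references out.
DGL_REGISTER_GLOBAL("example._CAPI_NewStyle")          // hypothetical name
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    *rv = ImmutableGraph::AsNumBits(ig, 64);           // ImmutableGraphPtr assigned directly
  });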
......@@ -25,25 +25,4 @@ PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) {
return PackedFunc(body);
}
DGL_REGISTER_GLOBAL("_GetVectorWrapperSize")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
const CAPIVectorWrapper* wrapper = static_cast<const CAPIVectorWrapper*>(ptr);
*rv = static_cast<int64_t>(wrapper->pointers.size());
});
DGL_REGISTER_GLOBAL("_GetVectorWrapperData")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
CAPIVectorWrapper* wrapper = static_cast<CAPIVectorWrapper*>(ptr);
*rv = static_cast<void*>(wrapper->pointers.data());
});
DGL_REGISTER_GLOBAL("_FreeVectorWrapper")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
CAPIVectorWrapper* wrapper = static_cast<CAPIVectorWrapper*>(ptr);
delete wrapper;
});
} // namespace dgl
......@@ -34,9 +34,6 @@ inline std::ostream& operator << (std::ostream& os, const DLContext& ctx) {
namespace dgl {
// Graph handler type
typedef void* GraphHandle;
// Communicator handler type
typedef void* CommunicatorHandle;
......@@ -73,29 +70,6 @@ dgl::runtime::NDArray CopyVectorToNDArray(
return a;
}
/* A structure used to return a vector of void* pointers. */
struct CAPIVectorWrapper {
// The pointer vector.
std::vector<void*> pointers;
};
/*!
* \brief A helper function used to return vector of pointers from C to frontend.
*
* Note that the function will move the given vector memory into the returned
* wrapper object.
*
* \param vec The given pointer vectors.
* \return A wrapper object containing the given data.
*/
template<typename PType>
CAPIVectorWrapper* WrapVectorReturn(std::vector<PType*> vec) {
CAPIVectorWrapper* wrapper = new CAPIVectorWrapper;
wrapper->pointers.reserve(vec.size());
wrapper->pointers.insert(wrapper->pointers.end(), vec.begin(), vec.end());
return wrapper;
}
} // namespace dgl
#endif // DGL_C_API_COMMON_H_
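With CAPIVectorWrapper, WrapVectorReturn and the _GetVectorWrapperSize/_GetVectorWrapperData/_FreeVectorWrapper globals removed above, returning a collection to the frontend now goes through the object system's List container, as the network and sampling hunks below show. A minimal sketch of that pattern, assuming the body already has NodeFlow objects to return; the API name is hypothetical:

#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>     // List<>

using namespace dgl::runtime;

DGL_REGISTER_GLOBAL("example._CAPI_ReturnNodeFlows")   // hypothetical name
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    List<NodeFlow> flows;
    flows.push_back(NodeFlow::Create());
    *rv = flows;   // reference counted; no _FreeVectorWrapper counterpart needed
  });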
......@@ -200,7 +200,7 @@ IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
}
// O(E*k) pretty slow
Graph::EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
......@@ -246,7 +246,7 @@ Graph::EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
return EdgeArray{rst_src, rst_dst, rst_eid};
}
Graph::EdgeArray Graph::FindEdges(IdArray eids) const {
EdgeArray Graph::FindEdges(IdArray eids) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
int64_t len = eids->shape[0];
......@@ -272,7 +272,7 @@ Graph::EdgeArray Graph::FindEdges(IdArray eids) const {
}
// O(E)
Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const {
EdgeArray Graph::InEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = reverse_adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
......@@ -290,7 +290,7 @@ Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const {
}
// O(E)
Graph::EdgeArray Graph::InEdges(IdArray vids) const {
EdgeArray Graph::InEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
......@@ -318,7 +318,7 @@ Graph::EdgeArray Graph::InEdges(IdArray vids) const {
}
// O(E)
Graph::EdgeArray Graph::OutEdges(dgl_id_t vid) const {
EdgeArray Graph::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
......@@ -336,7 +336,7 @@ Graph::EdgeArray Graph::OutEdges(dgl_id_t vid) const {
}
// O(E)
Graph::EdgeArray Graph::OutEdges(IdArray vids) const {
EdgeArray Graph::OutEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
......@@ -364,7 +364,7 @@ Graph::EdgeArray Graph::OutEdges(IdArray vids) const {
}
// O(E*log(E)) if sort is required; otherwise, O(E)
Graph::EdgeArray Graph::Edges(const std::string &order) const {
EdgeArray Graph::Edges(const std::string &order) const {
const int64_t len = num_edges_;
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
......@@ -585,9 +585,4 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const
}
}
GraphPtr Graph::Reverse() const {
LOG(FATAL) << "not implemented";
return nullptr;
}
} // namespace dgl
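A note on the recurring Graph::EdgeArray to EdgeArray change in the signatures above and below: for an out-of-class definition the return type is looked up before the class scope is entered, so this only compiles if EdgeArray is visible at namespace scope. The type has presumably been hoisted out of the graph classes (e.g., into graph_interface.h); the shape sketched below is inferred from aggregate initializations such as EdgeArray{rst_src, rst_dst, rst_eid}, and the field names are an assumption:

// Presumed namespace-scope definition (field names are an assumption).
struct EdgeArray {
  IdArray src;
  IdArray dst;
  IdArray id;
};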
......@@ -4,6 +4,7 @@
* \brief DGL immutable graph index implementation
*/
#include <dgl/packed_func_ext.h>
#include <dgl/immutable_graph.h>
#include <string.h>
#include <bitset>
......@@ -12,8 +13,14 @@
#include "../c_api_common.h"
using namespace dgl::runtime;
namespace dgl {
namespace {
inline std::string GetSharedMemName(const std::string &name, const std::string &edge_dir) {
return name + "_" + edge_dir;
}
std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) {
#ifndef _WIN32
......@@ -124,22 +131,22 @@ bool CSR::IsMultigraph() const {
});
}
CSR::EdgeArray CSR::OutEdges(dgl_id_t vid) const {
EdgeArray CSR::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
return CSR::EdgeArray{ret_src, ret_dst, ret_eid};
return EdgeArray{ret_src, ret_dst, ret_eid};
}
CSR::EdgeArray CSR::OutEdges(IdArray vids) const {
EdgeArray CSR::OutEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
auto csrsubmat = aten::CSRSliceRows(adj_, vids);
auto coosubmat = aten::CSRToCOO(csrsubmat, false);
// Note that the row id in the csr submat is relabeled, so
// we need to recover it using an index select.
auto row = aten::IndexSelect(vids, coosubmat.row);
return CSR::EdgeArray{row, coosubmat.col, coosubmat.data};
return EdgeArray{row, coosubmat.col, coosubmat.data};
}
DegreeArray CSR::OutDegrees(IdArray vids) const {
......@@ -171,17 +178,17 @@ IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {
return aten::CSRGetData(adj_, src, dst);
}
CSR::EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
const auto& arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);
return CSR::EdgeArray{arrs[0], arrs[1], arrs[2]};
return EdgeArray{arrs[0], arrs[1], arrs[2]};
}
CSR::EdgeArray CSR::Edges(const std::string &order) const {
EdgeArray CSR::Edges(const std::string &order) const {
CHECK(order.empty() || order == std::string("srcdst"))
<< "CSR only support Edges of order \"srcdst\","
<< " but got \"" << order << "\".";
const auto& coo = aten::CSRToCOO(adj_, false);
return CSR::EdgeArray{coo.row, coo.col, coo.data};
return EdgeArray{coo.row, coo.col, coo.data};
}
Subgraph CSR::VertexSubgraph(IdArray vids) const {
......@@ -289,14 +296,14 @@ std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
return std::pair<dgl_id_t, dgl_id_t>(src, dst);
}
COO::EdgeArray COO::FindEdges(IdArray eids) const {
EdgeArray COO::FindEdges(IdArray eids) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
return EdgeArray{aten::IndexSelect(adj_.row, eids),
aten::IndexSelect(adj_.col, eids),
eids};
}
COO::EdgeArray COO::Edges(const std::string &order) const {
EdgeArray COO::Edges(const std::string &order) const {
CHECK(order.empty() || order == std::string("eid"))
<< "COO only support Edges of order \"eid\", but got \""
<< order << "\".";
......@@ -411,7 +418,7 @@ COOPtr ImmutableGraph::GetCOO() const {
return coo_;
}
ImmutableGraph::EdgeArray ImmutableGraph::Edges(const std::string &order) const {
EdgeArray ImmutableGraph::Edges(const std::string &order) const {
if (order.empty()) {
// arbitrary order
if (in_csr_) {
......@@ -467,53 +474,177 @@ std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &f
}
}
ImmutableGraph ImmutableGraph::ToImmutable(const GraphInterface* graph) {
const ImmutableGraph* ig = dynamic_cast<const ImmutableGraph*>(graph);
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir) {
CSRPtr csr(new CSR(indptr, indices, edge_ids));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir,
const std::string &shared_mem_name) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, GetSharedMemName(shared_mem_name, edge_dir)));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir,
const std::string &shared_mem_name) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph,
GetSharedMemName(shared_mem_name, edge_dir)));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
const std::string &shared_mem_name, size_t num_vertices,
size_t num_edges, bool multigraph,
const std::string &edge_dir) {
CSRPtr csr(new CSR(GetSharedMemName(shared_mem_name, edge_dir), num_vertices, num_edges,
multigraph));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst) {
COOPtr coo(new COO(num_vertices, src, dst));
return std::make_shared<ImmutableGraph>(coo);
}
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst, bool multigraph) {
COOPtr coo(new COO(num_vertices, src, dst, multigraph));
return std::make_shared<ImmutableGraph>(coo);
}
ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {
ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(graph);
if (ig) {
return *ig;
return ig;
} else {
const auto& adj = graph->GetAdj(true, "csr");
CSRPtr csr(new CSR(adj[0], adj[1], adj[2]));
return ImmutableGraph(nullptr, csr);
return ImmutableGraph::CreateFromCSR(adj[0], adj[1], adj[2], "out");
}
}
ImmutableGraph ImmutableGraph::CopyTo(const DLContext& ctx) const {
if (ctx == Context()) {
return *this;
ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DLContext& ctx) {
if (ctx == g->Context()) {
return g;
}
// TODO(minjie): since we don't have GPU implementation of COO<->CSR,
// we make sure that this graph (on CPU) has materialized CSR,
// and then copy them to other context (usually GPU). This should
// be fixed later.
CSRPtr new_incsr = CSRPtr(new CSR(GetInCSR()->CopyTo(ctx)));
CSRPtr new_outcsr = CSRPtr(new CSR(GetOutCSR()->CopyTo(ctx)));
return ImmutableGraph(new_incsr, new_outcsr);
CSRPtr new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyTo(ctx)));
CSRPtr new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyTo(ctx)));
return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr));
}
ImmutableGraph ImmutableGraph::CopyToSharedMem(const std::string &edge_dir,
const std::string &name) const {
ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(ImmutableGraphPtr g,
const std::string &edge_dir, const std::string &name) {
CSRPtr new_incsr, new_outcsr;
std::string shared_mem_name = GetSharedMemName(name, edge_dir);
if (edge_dir == std::string("in"))
new_incsr = CSRPtr(new CSR(GetInCSR()->CopyToSharedMem(shared_mem_name)));
new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyToSharedMem(shared_mem_name)));
else if (edge_dir == std::string("out"))
new_outcsr = CSRPtr(new CSR(GetOutCSR()->CopyToSharedMem(shared_mem_name)));
return ImmutableGraph(new_incsr, new_outcsr, name);
new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyToSharedMem(shared_mem_name)));
return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr, name));
}
ImmutableGraph ImmutableGraph::AsNumBits(uint8_t bits) const {
if (NumBits() == bits) {
return *this;
ImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) {
if (g->NumBits() == bits) {
return g;
} else {
// TODO(minjie): since we don't have int32 operations,
// we make sure that this graph (on CPU) has materialized CSR,
// and then convert the CSR index arrays to the requested bit width. This should
// be fixed later.
CSRPtr new_incsr = CSRPtr(new CSR(GetInCSR()->AsNumBits(bits)));
CSRPtr new_outcsr = CSRPtr(new CSR(GetOutCSR()->AsNumBits(bits)));
return ImmutableGraph(new_incsr, new_outcsr);
CSRPtr new_incsr = CSRPtr(new CSR(g->GetInCSR()->AsNumBits(bits)));
CSRPtr new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->AsNumBits(bits)));
return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr));
}
}
ImmutableGraphPtr ImmutableGraph::Reverse() const {
if (coo_) {
return ImmutableGraphPtr(new ImmutableGraph(
out_csr_, in_csr_, coo_->Transpose()));
} else {
return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_));
}
}
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const int device_type = args[1];
const int device_id = args[2];
DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::CopyTo(ig, ctx);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
std::string edge_dir = args[1];
std::string name = args[2];
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::CopyToSharedMem(ig, edge_dir, name);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
int bits = args[1];
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::AsNumBits(ig, bits);
});
} // namespace dgl
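ImmutableGraph construction, device copy, shared-memory copy and bit-width conversion are now exposed as static members returning ImmutableGraphPtr (a shared pointer), which the registrations above hand straight to *rv. An illustrative call site, assuming aten::VecToIdArray accepts a std::vector<dgl_id_t> as in the sampler hunks below and that a GPU device 0 is available; this is a sketch, not code from the commit:

// Build a 3-node out-CSR graph, then move it to device 0.
IdArray indptr   = aten::VecToIdArray(std::vector<dgl_id_t>{0, 2, 3, 3});
IdArray indices  = aten::VecToIdArray(std::vector<dgl_id_t>{1, 2, 0});
IdArray edge_ids = aten::VecToIdArray(std::vector<dgl_id_t>{0, 1, 2});
ImmutableGraphPtr g = ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, "out");
ImmutableGraphPtr g_gpu = ImmutableGraph::CopyTo(g, DLContext{kDLGPU, 0});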
......@@ -4,6 +4,8 @@
* \brief DGL networking related APIs
*/
#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include "./network.h"
#include "./network/communicator.h"
#include "./network/socket_communicator.h"
......@@ -11,11 +13,7 @@
#include "../c_api_common.h"
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
using namespace dgl::runtime;
namespace dgl {
namespace network {
......@@ -84,12 +82,14 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
int recv_id = args[1];
GraphHandle ghandle = args[2];
// TODO(minjie): could simply use NodeFlow nf = args[2];
GraphRef g = args[2];
const IdArray node_mapping = args[3];
const IdArray edge_mapping = args[4];
const IdArray layer_offsets = args[5];
const IdArray flow_offsets = args[6];
ImmutableGraph *ptr = static_cast<ImmutableGraph*>(ghandle);
auto ptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(ptr) << "only immutable graph is allowed in send/recv";
network::Sender* sender = static_cast<network::Sender*>(chandle);
auto csr = ptr->GetInCSR();
// Write control message
......@@ -159,7 +159,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
RecvData(receiver, buffer, kMaxBufferSize);
int control = *buffer;
if (control == CONTROL_NODEFLOW) {
NodeFlow* nf = new NodeFlow();
NodeFlow nf = NodeFlow::Create();
CSRPtr csr;
// Deserialize nodeflow from recv_data_buffer
network::DeserializeSampledSubgraph(buffer+sizeof(CONTROL_NODEFLOW),
......@@ -169,9 +169,9 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
&(nf->layer_offsets),
&(nf->flow_offsets));
nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr));
std::vector<NodeFlow*> subgs(1);
subgs[0] = nf;
*rv = WrapVectorReturn(subgs);
List<NodeFlow> subgs;
subgs.push_back(nf);
*rv = subgs;
} else if (control == CONTROL_END_SIGNAL) {
*rv = CONTROL_END_SIGNAL;
} else {
......
......@@ -5,9 +5,10 @@
*/
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/nodeflow.h>
#include <string.h>
#include <string>
#include "../c_api_common.h"
......@@ -78,15 +79,14 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetBlockAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
GraphRef g = args[0];
std::string format = args[1];
int64_t layer0_size = args[2];
int64_t start = args[3];
int64_t end = args[4];
const bool remap = args[5];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
const ImmutableGraph* gptr = dynamic_cast<const ImmutableGraph*>(ptr);
auto res = GetNodeFlowSlice(*gptr, format, layer0_size, start, end, remap);
auto ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
auto res = GetNodeFlowSlice(*ig, format, layer0_size, start, end, remap);
*rv = ConvertNDArrayVectorToPackedFunc(res);
});
......
......@@ -7,6 +7,7 @@
#include <dgl/sampler.h>
#include <dmlc/omp.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <algorithm>
#include <cstdlib>
#include <cmath>
......@@ -14,11 +15,7 @@
#include <functional>
#include "../c_api_common.h"
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
using namespace dgl::runtime;
namespace dgl {
......@@ -218,43 +215,40 @@ RandomWalkTraces BipartiteSingleSidedRandomWalkWithRestart(
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
GraphRef g = args[0];
const IdArray seeds = args[1];
const int num_traces = args[2];
const int num_hops = args[3];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
*rv = RandomWalk(ptr, seeds, num_traces, num_hops);
*rv = RandomWalk(g.sptr().get(), seeds, num_traces, num_hops);
});
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
GraphRef g = args[0];
const IdArray seeds = args[1];
const double restart_prob = args[2];
const uint64_t visit_threshold_per_seed = args[3];
const uint64_t max_visit_counts = args[4];
const uint64_t max_frequent_visited_nodes = args[5];
const GraphInterface *gptr = static_cast<const GraphInterface *>(ghandle);
*rv = ConvertRandomWalkTracesToPackedFunc(
RandomWalkWithRestart(gptr, seeds, restart_prob, visit_threshold_per_seed,
RandomWalkWithRestart(g.sptr().get(), seeds, restart_prob, visit_threshold_per_seed,
max_visit_counts, max_frequent_visited_nodes));
});
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLBipartiteSingleSidedRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
GraphRef g = args[0];
const IdArray seeds = args[1];
const double restart_prob = args[2];
const uint64_t visit_threshold_per_seed = args[3];
const uint64_t max_visit_counts = args[4];
const uint64_t max_frequent_visited_nodes = args[5];
const GraphInterface *gptr = static_cast<const GraphInterface *>(ghandle);
*rv = ConvertRandomWalkTracesToPackedFunc(
BipartiteSingleSidedRandomWalkWithRestart(
gptr, seeds, restart_prob, visit_threshold_per_seed,
g.sptr().get(), seeds, restart_prob, visit_threshold_per_seed,
max_visit_counts, max_frequent_visited_nodes));
});
......
......@@ -3,10 +3,10 @@
* \file graph/sampler.cc
* \brief DGL sampler implementation
*/
#include <dgl/sampler.h>
#include <dmlc/omp.h>
#include <dgl/immutable_graph.h>
#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include <dmlc/omp.h>
#include <algorithm>
#include <cstdlib>
......@@ -14,11 +14,7 @@
#include <numeric>
#include "../c_api_common.h"
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
using namespace dgl::runtime;
namespace dgl {
......@@ -246,17 +242,17 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
std::vector<neighbor_info> *neigh_pos,
const std::string &edge_type,
int64_t num_edges, int num_hops, bool is_multigraph) {
NodeFlow nf;
NodeFlow nf = NodeFlow::Create();
uint64_t num_vertices = sub_vers->size();
nf.node_mapping = aten::NewIdArray(num_vertices);
nf.edge_mapping = aten::NewIdArray(num_edges);
nf.layer_offsets = aten::NewIdArray(num_hops + 1);
nf.flow_offsets = aten::NewIdArray(num_hops);
nf->node_mapping = aten::NewIdArray(num_vertices);
nf->edge_mapping = aten::NewIdArray(num_edges);
nf->layer_offsets = aten::NewIdArray(num_hops + 1);
nf->flow_offsets = aten::NewIdArray(num_hops);
dgl_id_t *node_map_data = static_cast<dgl_id_t *>(nf.node_mapping->data);
dgl_id_t *layer_off_data = static_cast<dgl_id_t *>(nf.layer_offsets->data);
dgl_id_t *flow_off_data = static_cast<dgl_id_t *>(nf.flow_offsets->data);
dgl_id_t *edge_map_data = static_cast<dgl_id_t *>(nf.edge_mapping->data);
dgl_id_t *node_map_data = static_cast<dgl_id_t *>(nf->node_mapping->data);
dgl_id_t *layer_off_data = static_cast<dgl_id_t *>(nf->layer_offsets->data);
dgl_id_t *flow_off_data = static_cast<dgl_id_t *>(nf->flow_offsets->data);
dgl_id_t *edge_map_data = static_cast<dgl_id_t *>(nf->edge_mapping->data);
// Construct sub_csr_graph
// TODO(minjie): is nodeflow a multigraph?
......@@ -364,9 +360,9 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
std::iota(eid_out, eid_out + num_edges, 0);
if (edge_type == std::string("in")) {
nf.graph = GraphPtr(new ImmutableGraph(subg_csr, nullptr));
nf->graph = GraphPtr(new ImmutableGraph(subg_csr, nullptr));
} else {
nf.graph = GraphPtr(new ImmutableGraph(nullptr, subg_csr));
nf->graph = GraphPtr(new ImmutableGraph(nullptr, subg_csr));
}
return nf;
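ConstructNodeFlow now fills a reference-counted NodeFlow object (created by NodeFlow::Create(), with fields reached through operator->) and returns it by value, which is also why the explicit nodeflow._CAPI_NodeFlowFree registration disappears further down. Condensed from the hunk above, for orientation only:

NodeFlow nf = NodeFlow::Create();
nf->node_mapping  = aten::NewIdArray(num_vertices);
nf->edge_mapping  = aten::NewIdArray(num_edges);
nf->layer_offsets = aten::NewIdArray(num_hops + 1);
nf->flow_offsets  = aten::NewIdArray(num_hops);
nf->graph = GraphPtr(new ImmutableGraph(subg_csr, nullptr));   // "in" edge type case
return nf;   // returned by value; the object system manages its lifetime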
......@@ -491,47 +487,34 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
const NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
GraphInterface* gptr = nflow->graph->Reset();
*rv = gptr;
NodeFlow nflow = args[0];
*rv = nflow->graph;
});
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetNodeMapping")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
const NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
NodeFlow nflow = args[0];
*rv = nflow->node_mapping;
});
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetEdgeMapping")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
const NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
NodeFlow nflow = args[0];
*rv = nflow->edge_mapping;
});
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetLayerOffsets")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
const NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
NodeFlow nflow = args[0];
*rv = nflow->layer_offsets;
});
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetBlockOffsets")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
const NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
NodeFlow nflow = args[0];
*rv = nflow->flow_offsets;
});
DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowFree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
void* ptr = args[0];
NodeFlow* nflow = static_cast<NodeFlow*>(ptr);
delete nflow;
});
NodeFlow SamplerOp::NeighborUniformSample(const ImmutableGraph *graph,
const std::vector<dgl_id_t>& seeds,
const std::string &edge_type,
......@@ -702,21 +685,21 @@ NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph,
CHECK_EQ(sub_indptr.back(), sub_indices.size());
CHECK_EQ(sub_indices.size(), sub_edge_ids.size());
NodeFlow nf;
NodeFlow nf = NodeFlow::Create();
auto sub_csr = CSRPtr(new CSR(aten::VecToIdArray(sub_indptr),
aten::VecToIdArray(sub_indices),
aten::VecToIdArray(sub_edge_ids)));
if (neighbor_type == std::string("in")) {
nf.graph = GraphPtr(new ImmutableGraph(sub_csr, nullptr));
nf->graph = GraphPtr(new ImmutableGraph(sub_csr, nullptr));
} else {
nf.graph = GraphPtr(new ImmutableGraph(nullptr, sub_csr));
nf->graph = GraphPtr(new ImmutableGraph(nullptr, sub_csr));
}
nf.node_mapping = aten::VecToIdArray(node_mapping);
nf.edge_mapping = aten::VecToIdArray(edge_mapping);
nf.layer_offsets = aten::VecToIdArray(layer_offsets);
nf.flow_offsets = aten::VecToIdArray(flow_offsets);
nf->node_mapping = aten::VecToIdArray(node_mapping);
nf->edge_mapping = aten::VecToIdArray(edge_mapping);
nf->layer_offsets = aten::VecToIdArray(layer_offsets);
nf->flow_offsets = aten::VecToIdArray(flow_offsets);
return nf;
}
......@@ -736,7 +719,7 @@ void BuildCsr(const ImmutableGraph &g, const std::string neigh_type) {
DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
// arguments
const GraphHandle ghdl = args[0];
GraphRef g = args[0];
const IdArray seed_nodes = args[1];
const int64_t batch_start_id = args[2];
const int64_t batch_size = args[3];
......@@ -746,8 +729,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
const std::string neigh_type = args[7];
const bool add_self_loop = args[8];
// process args
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghdl);
const ImmutableGraph *gptr = dynamic_cast<const ImmutableGraph*>(ptr);
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(IsValidIdArray(seed_nodes));
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
......@@ -757,7 +739,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
// We need to make sure we have the right CSR before we enter parallel sampling.
BuildCsr(*gptr, neigh_type);
// generate node flows
std::vector<NodeFlow*> nflows(num_workers);
std::vector<NodeFlow> nflows(num_workers);
#pragma omp parallel for
for (int i = 0; i < num_workers; i++) {
// create per-worker seed nodes.
......@@ -767,17 +749,16 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
std::vector<dgl_id_t> worker_seeds(end - start);
std::copy(seed_nodes_data + start, seed_nodes_data + end,
worker_seeds.begin());
nflows[i] = new NodeFlow();
*nflows[i] = SamplerOp::NeighborUniformSample(
gptr, worker_seeds, neigh_type, num_hops, expand_factor, add_self_loop);
nflows[i] = SamplerOp::NeighborUniformSample(
gptr.get(), worker_seeds, neigh_type, num_hops, expand_factor, add_self_loop);
}
*rv = WrapVectorReturn(nflows);
*rv = List<NodeFlow>(nflows);
});
DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
// arguments
const GraphHandle ghdl = args[0];
GraphRef g = args[0];
const IdArray seed_nodes = args[1];
const int64_t batch_start_id = args[2];
const int64_t batch_size = args[3];
......@@ -785,8 +766,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
const IdArray layer_sizes = args[5];
const std::string neigh_type = args[6];
// process args
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghdl);
const ImmutableGraph *gptr = dynamic_cast<const ImmutableGraph*>(ptr);
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(IsValidIdArray(seed_nodes));
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
......@@ -796,7 +776,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
// We need to make sure we have the right CSR before we enter parallel sampling.
BuildCsr(*gptr, neigh_type);
// generate node flows
std::vector<NodeFlow*> nflows(num_workers);
std::vector<NodeFlow> nflows(num_workers);
#pragma omp parallel for
for (int i = 0; i < num_workers; i++) {
// create per-worker seed nodes.
......@@ -806,11 +786,10 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
std::vector<dgl_id_t> worker_seeds(end - start);
std::copy(seed_nodes_data + start, seed_nodes_data + end,
worker_seeds.begin());
nflows[i] = new NodeFlow();
*nflows[i] = SamplerOp::LayerUniformSample(
gptr, worker_seeds, neigh_type, layer_sizes);
nflows[i] = SamplerOp::LayerUniformSample(
gptr.get(), worker_seeds, neigh_type, layer_sizes);
}
*rv = WrapVectorReturn(nflows);
*rv = List<NodeFlow>(nflows);
});
} // namespace dgl
......@@ -3,16 +3,13 @@
* \file graph/traversal.cc
* \brief Graph traversal implementation
*/
#include <dgl/packed_func_ext.h>
#include <algorithm>
#include <queue>
#include "./traversal.h"
#include "../c_api_common.h"
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
using namespace dgl::runtime;
namespace dgl {
namespace traverse {
......@@ -115,7 +112,7 @@ struct Frontiers {
std::vector<int64_t> sections;
};
Frontiers BFSNodesFrontiers(const Graph& graph, IdArray source, bool reversed) {
Frontiers BFSNodesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) {
Frontiers front;
VectorQueueWrapper<dgl_id_t> queue(&front.ids);
auto visit = [&] (const dgl_id_t v) { };
......@@ -131,17 +128,16 @@ Frontiers BFSNodesFrontiers(const Graph& graph, IdArray source, bool reversed) {
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
GraphRef g = args[0];
const IdArray src = args[1];
bool reversed = args[2];
const auto& front = BFSNodesFrontiers(*gptr, src, reversed);
const auto& front = BFSNodesFrontiers(*(g.sptr()), src, reversed);
IdArray node_ids = CopyVectorToNDArray(front.ids);
IdArray sections = CopyVectorToNDArray(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
});
Frontiers BFSEdgesFrontiers(const Graph& graph, IdArray source, bool reversed) {
Frontiers BFSEdgesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) {
Frontiers front;
// NOTE: std::queue has no top() method.
std::vector<dgl_id_t> nodes;
......@@ -162,17 +158,16 @@ Frontiers BFSEdgesFrontiers(const Graph& graph, IdArray source, bool reversed) {
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
GraphRef g = args[0];
const IdArray src = args[1];
bool reversed = args[2];
const auto& front = BFSEdgesFrontiers(*gptr, src, reversed);
const auto& front = BFSEdgesFrontiers(*(g.sptr()), src, reversed);
IdArray edge_ids = CopyVectorToNDArray(front.ids);
IdArray sections = CopyVectorToNDArray(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections});
});
Frontiers TopologicalNodesFrontiers(const Graph& graph, bool reversed) {
Frontiers TopologicalNodesFrontiers(const GraphInterface& graph, bool reversed) {
Frontiers front;
VectorQueueWrapper<dgl_id_t> queue(&front.ids);
auto visit = [&] (const dgl_id_t v) { };
......@@ -188,10 +183,9 @@ Frontiers TopologicalNodesFrontiers(const Graph& graph, bool reversed) {
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
GraphRef g = args[0];
bool reversed = args[1];
const auto& front = TopologicalNodesFrontiers(*gptr, reversed);
const auto& front = TopologicalNodesFrontiers(*g.sptr(), reversed);
IdArray node_ids = CopyVectorToNDArray(front.ids);
IdArray sections = CopyVectorToNDArray(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
......@@ -200,8 +194,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
GraphRef g = args[0];
const IdArray source = args[1];
const bool reversed = args[2];
CHECK(IsValidIdArray(source)) << "Invalid source node id array.";
......@@ -210,7 +203,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
std::vector<std::vector<dgl_id_t>> edges(len);
for (int64_t i = 0; i < len; ++i) {
auto visit = [&] (dgl_id_t e, int tag) { edges[i].push_back(e); };
DFSLabeledEdges(*gptr, src_data[i], reversed, false, false, visit);
DFSLabeledEdges(*g.sptr(), src_data[i], reversed, false, false, visit);
}
IdArray ids = MergeMultipleTraversals(edges);
IdArray sections = ComputeMergedSections(edges);
......@@ -219,8 +212,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
GraphRef g = args[0];
const IdArray source = args[1];
const bool reversed = args[2];
const bool has_reverse_edge = args[3];
......@@ -243,7 +235,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
tags[i].push_back(tag);
}
};
DFSLabeledEdges(*gptr, src_data[i], reversed,
DFSLabeledEdges(*g.sptr(), src_data[i], reversed,
has_reverse_edge, has_nontree_edge, visit);
}
......
......@@ -11,7 +11,7 @@
#ifndef DGL_GRAPH_TRAVERSAL_H_
#define DGL_GRAPH_TRAVERSAL_H_
#include <dgl/graph.h>
#include <dgl/graph_interface.h>
#include <stack>
#include <tuple>
#include <vector>
......@@ -45,7 +45,7 @@ namespace traverse {
* \param make_frontier The function to indicate that a new frontier can be made;
*/
template<typename Queue, typename VisitFn, typename FrontierFn>
void BFSNodes(const Graph& graph,
void BFSNodes(const GraphInterface& graph,
IdArray source,
bool reversed,
Queue* queue,
......@@ -63,7 +63,7 @@ void BFSNodes(const Graph& graph,
}
make_frontier();
const auto neighbor_iter = reversed? &Graph::PredVec : &Graph::SuccVec;
const auto neighbor_iter = reversed? &GraphInterface::PredVec : &GraphInterface::SuccVec;
while (!queue->empty()) {
const size_t size = queue->size();
for (size_t i = 0; i < size; ++i) {
......@@ -109,7 +109,7 @@ void BFSNodes(const Graph& graph,
* \param make_frontier The function to indicate that a new frontier can be made;
*/
template<typename Queue, typename VisitFn, typename FrontierFn>
void BFSEdges(const Graph& graph,
void BFSEdges(const GraphInterface& graph,
IdArray source,
bool reversed,
Queue* queue,
......@@ -126,7 +126,7 @@ void BFSEdges(const Graph& graph,
}
make_frontier();
const auto neighbor_iter = reversed? &Graph::InEdgeVec : &Graph::OutEdgeVec;
const auto neighbor_iter = reversed? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;
while (!queue->empty()) {
const size_t size = queue->size();
for (size_t i = 0; i < size; ++i) {
......@@ -171,13 +171,13 @@ void BFSEdges(const Graph& graph,
* \param make_frontier The function to indicate that a new frontier can be made;
*/
template<typename Queue, typename VisitFn, typename FrontierFn>
void TopologicalNodes(const Graph& graph,
void TopologicalNodes(const GraphInterface& graph,
bool reversed,
Queue* queue,
VisitFn visit,
FrontierFn make_frontier) {
const auto get_degree = reversed? &Graph::OutDegree : &Graph::InDegree;
const auto neighbor_iter = reversed? &Graph::PredVec : &Graph::SuccVec;
const auto get_degree = reversed? &GraphInterface::OutDegree : &GraphInterface::InDegree;
const auto neighbor_iter = reversed? &GraphInterface::PredVec : &GraphInterface::SuccVec;
uint64_t num_visited_nodes = 0;
std::vector<uint64_t> degrees(graph.NumVertices(), 0);
for (dgl_id_t vid = 0; vid < graph.NumVertices(); ++vid) {
......@@ -237,14 +237,14 @@ enum DFSEdgeTag {
* tag will be given as the arguments.
*/
template<typename VisitFn>
void DFSLabeledEdges(const Graph& graph,
void DFSLabeledEdges(const GraphInterface& graph,
dgl_id_t source,
bool reversed,
bool has_reverse_edge,
bool has_nontree_edge,
VisitFn visit) {
const auto succ = reversed? &Graph::PredVec : &Graph::SuccVec;
const auto out_edge = reversed? &Graph::InEdgeVec : &Graph::OutEdgeVec;
const auto succ = reversed? &GraphInterface::PredVec : &GraphInterface::SuccVec;
const auto out_edge = reversed? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;
if ((graph.*succ)(source).size() == 0) {
// no out-going edges from the source node
......
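The traversal templates are now generic over GraphInterface instead of the concrete Graph, so the same BFS/DFS/topological drivers serve both mutable and immutable graphs. A minimal driver for BFSNodes, modeled on the BFSNodesFrontiers caller in traversal.cc above; the VectorQueueWrapper helpers and the frontier bookkeeping are assumed to behave as in that caller:

Frontiers BFSNodesFrontiersSketch(const GraphInterface& graph, IdArray source, bool reversed) {
  Frontiers front;
  VectorQueueWrapper<dgl_id_t> queue(&front.ids);     // queue that appends into front.ids
  auto visit = [&] (const dgl_id_t v) { /* per-node hook; nothing extra recorded here */ };
  auto make_frontier = [&] () {
    if (!queue.empty()) {
      front.sections.push_back(queue.size());         // record the size of the new frontier
    }
  };
  BFSNodes(graph, source, reversed, &queue, visit, make_frontier);
  return front;
}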
......@@ -3,17 +3,14 @@
* \file kernel/binary_reduce.cc
* \brief Binary reduce C APIs and definitions.
*/
#include <dgl/packed_func_ext.h>
#include "./binary_reduce.h"
#include "./common.h"
#include "./binary_reduce_impl_decl.h"
#include "./utils.h"
#include "../c_api_common.h"
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
using namespace dgl::runtime;
namespace dgl {
namespace kernel {
......@@ -273,7 +270,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0];
std::string op = args[1];
GraphHandle ghdl = args[2];
GraphRef g = args[2];
int lhs = args[3];
int rhs = args[4];
NDArray lhs_data = args[5];
......@@ -283,10 +280,9 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
NDArray rhs_mapping = args[9];
NDArray out_mapping = args[10];
GraphInterface* gptr = static_cast<GraphInterface*>(ghdl);
const ImmutableGraph* igptr = dynamic_cast<ImmutableGraph*>(gptr);
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
BinaryOpReduce(reducer, op, igptr,
BinaryOpReduce(reducer, op, igptr.get(),
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_data, rhs_data, out_data,
lhs_mapping, rhs_mapping, out_mapping);
......@@ -346,7 +342,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0];
std::string op = args[1];
GraphHandle ghdl = args[2];
GraphRef g = args[2];
int lhs = args[3];
int rhs = args[4];
NDArray lhs_mapping = args[5];
......@@ -358,11 +354,10 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
NDArray grad_out_data = args[11];
NDArray grad_lhs_data = args[12];
GraphInterface* gptr = static_cast<GraphInterface*>(ghdl);
const ImmutableGraph* igptr = dynamic_cast<ImmutableGraph*>(gptr);
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
BackwardLhsBinaryOpReduce(
reducer, op, igptr,
reducer, op, igptr.get(),
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_mapping, rhs_mapping, out_mapping,
lhs_data, rhs_data, out_data, grad_out_data,
......@@ -422,7 +417,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0];
std::string op = args[1];
GraphHandle ghdl = args[2];
GraphRef g = args[2];
int lhs = args[3];
int rhs = args[4];
NDArray lhs_mapping = args[5];
......@@ -434,11 +429,10 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
NDArray grad_out_data = args[11];
NDArray grad_rhs_data = args[12];
GraphInterface* gptr = static_cast<GraphInterface*>(ghdl);
const ImmutableGraph* igptr = dynamic_cast<ImmutableGraph*>(gptr);
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
BackwardRhsBinaryOpReduce(
reducer, op, igptr,
reducer, op, igptr.get(),
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_mapping, rhs_mapping, out_mapping,
lhs_data, rhs_data, out_data, grad_out_data,
......@@ -469,17 +463,16 @@ void CopyReduce(
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0];
GraphHandle ghdl = args[1];
GraphRef g = args[1];
int target = args[2];
NDArray in_data = args[3];
NDArray out_data = args[4];
NDArray in_mapping = args[5];
NDArray out_mapping = args[6];
GraphInterface* gptr = static_cast<GraphInterface*>(ghdl);
const ImmutableGraph* igptr = dynamic_cast<ImmutableGraph*>(gptr);
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
CopyReduce(reducer, igptr,
CopyReduce(reducer, igptr.get(),
static_cast<binary_op::Target>(target),
in_data, out_data,
in_mapping, out_mapping);
......@@ -518,7 +511,7 @@ void BackwardCopyReduce(
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string reducer = args[0];
GraphHandle ghdl = args[1];
GraphRef g = args[1];
int target = args[2];
NDArray in_data = args[3];
NDArray out_data = args[4];
......@@ -527,11 +520,10 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce")
NDArray in_mapping = args[7];
NDArray out_mapping = args[8];
GraphInterface* gptr = static_cast<GraphInterface*>(ghdl);
const ImmutableGraph* igptr = dynamic_cast<ImmutableGraph*>(gptr);
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
BackwardCopyReduce(
reducer, igptr, static_cast<binary_op::Target>(target),
reducer, igptr.get(), static_cast<binary_op::Target>(target),
in_mapping, out_mapping,
in_data, out_data, grad_out_data,
grad_in_data);
......