Unverified commit 788420df, authored by Quan (Andy) Gan, committed by GitHub

[Refactor] Decoupling ImmutableGraph from Kernels (#749)

* csr interface

* csr wrapper for immutable graph

* lint

* silly fix

* docstring
parent 35bed2a9
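
What the diff below does, in one sentence: every kernel entry point that used to take a `const ImmutableGraph*` now takes a `const CSRWrapper&` (a four-method abstract interface added in `kernel/csr_interface.h`), and the C API handlers are the only code left that names the concrete graph type, wrapping it in a stack-allocated adapter before dispatch. Here is a self-contained toy of the same pattern; all names in it (`ToyCSRView`, `ToyImmutableGraph`, `ToyAdapter`) are illustrative stand-ins, not DGL API:

```cpp
#include <cstdio>

// Toy version of the decoupling pattern in this commit: kernels code against
// an abstract CSR view; each concrete graph type gets a thin adapter that is
// built at the C API boundary.
struct ToyCSRView {                  // plays the role of kernel::CSRWrapper
  virtual ~ToyCSRView() = default;
  virtual int NumBits() const = 0;
};

struct ToyImmutableGraph {           // stands in for dgl::ImmutableGraph
  int bits = 64;
};

struct ToyAdapter : ToyCSRView {     // stands in for ImmutableGraphCSRWrapper
  explicit ToyAdapter(const ToyImmutableGraph* g) : g_(g) {}
  int NumBits() const override { return g_->bits; }
  const ToyImmutableGraph* g_;       // non-owning, like gptr_ in the diff
};

void Kernel(const ToyCSRView& graph) {   // kernels see only the interface
  std::printf("ids are %d-bit\n", graph.NumBits());
}

int main() {
  ToyImmutableGraph g;
  ToyAdapter wrapper(&g);   // stack-allocated at the call boundary
  Kernel(wrapper);
  return 0;
}
```

The payoff is that a second graph implementation (the doc comment in `csr_interface.h` names Bipartite) only needs its own small adapter rather than a second set of kernel overloads.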
src/kernel/binary_reduce.cc

@@ -4,11 +4,13 @@
  * \brief Binary reduce C APIs and definitions.
  */
 #include <dgl/packed_func_ext.h>
+#include <dgl/immutable_graph.h>
 #include "./binary_reduce.h"
 #include "./common.h"
 #include "./binary_reduce_impl_decl.h"
 #include "./utils.h"
 #include "../c_api_common.h"
+#include "./csr_interface.h"
 
 using namespace dgl::runtime;

@@ -201,6 +203,31 @@ inline bool NeedSwitchOrder(const std::string& op,
       && lhs > rhs;
 }
 
+class ImmutableGraphCSRWrapper : public CSRWrapper {
+ public:
+  explicit ImmutableGraphCSRWrapper(const ImmutableGraph* graph) :
+    gptr_(graph) { }
+
+  aten::CSRMatrix GetInCSRMatrix() const override {
+    return gptr_->GetInCSR()->ToCSRMatrix();
+  }
+
+  aten::CSRMatrix GetOutCSRMatrix() const override {
+    return gptr_->GetOutCSR()->ToCSRMatrix();
+  }
+
+  DGLContext Context() const override {
+    return gptr_->Context();
+  }
+
+  int NumBits() const override {
+    return gptr_->NumBits();
+  }
+
+ private:
+  const ImmutableGraph* gptr_;
+};
+
 }  // namespace

@@ -226,18 +253,18 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelInferBinaryFeatureShape")
 void BinaryOpReduce(
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     NDArray lhs_data, NDArray rhs_data,
     NDArray out_data,
     NDArray lhs_mapping, NDArray rhs_mapping,
     NDArray out_mapping) {
-  const auto& ctx = graph->Context();
+  const auto& ctx = graph.Context();
   // sanity check
   CheckCtx(ctx,
      {lhs_data, rhs_data, out_data, lhs_mapping, rhs_mapping, out_mapping},
      {"lhs_data", "rhs_data", "out_data", "lhs_mapping", "rhs_mapping", "out_mapping"});
-  CheckIdArray(graph->NumBits(),
+  CheckIdArray(graph.NumBits(),
      {lhs_mapping, rhs_mapping, out_mapping},
      {"lhs_mapping", "rhs_mapping", "out_mapping"});
   // Switch order for commutative operation

@@ -282,7 +309,8 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
     auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
     CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
-    BinaryOpReduce(reducer, op, igptr.get(),
+    ImmutableGraphCSRWrapper wrapper(igptr.get());
+    BinaryOpReduce(reducer, op, wrapper,
         static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
         lhs_data, rhs_data, out_data,
         lhs_mapping, rhs_mapping, out_mapping);

@@ -291,7 +319,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
 void BackwardLhsBinaryOpReduce(
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     NDArray lhs_mapping,
     NDArray rhs_mapping,

@@ -301,14 +329,14 @@ void BackwardLhsBinaryOpReduce(
     NDArray out_data,
     NDArray grad_out_data,
     NDArray grad_lhs_data) {
-  const auto& ctx = graph->Context();
+  const auto& ctx = graph.Context();
   // sanity check
   CheckCtx(ctx,
      {lhs_data, rhs_data, out_data, grad_out_data, grad_lhs_data,
       lhs_mapping, rhs_mapping, out_mapping},
      {"lhs_data", "rhs_data", "out_data", "grad_out_data", "grad_lhs_data",
       "lhs_mapping", "rhs_mapping", "out_mapping"});
-  CheckIdArray(graph->NumBits(),
+  CheckIdArray(graph.NumBits(),
      {lhs_mapping, rhs_mapping, out_mapping},
      {"lhs_mapping", "rhs_mapping", "out_mapping"});
   // Switch order for commutative operation

@@ -356,8 +384,9 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
     auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
     CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
+    ImmutableGraphCSRWrapper wrapper(igptr.get());
     BackwardLhsBinaryOpReduce(
-        reducer, op, igptr.get(),
+        reducer, op, wrapper,
        static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
        lhs_mapping, rhs_mapping, out_mapping,
        lhs_data, rhs_data, out_data, grad_out_data,

@@ -367,7 +396,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
 void BackwardRhsBinaryOpReduce(
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     NDArray lhs_mapping,
     NDArray rhs_mapping,

@@ -377,14 +406,14 @@ void BackwardRhsBinaryOpReduce(
     NDArray out_data,
     NDArray grad_out_data,
     NDArray grad_rhs_data) {
-  const auto& ctx = graph->Context();
+  const auto& ctx = graph.Context();
   // sanity check
   CheckCtx(ctx,
      {lhs_data, rhs_data, out_data, grad_out_data, grad_rhs_data,
       lhs_mapping, rhs_mapping, out_mapping},
      {"lhs_data", "rhs_data", "out_data", "grad_out_data", "grad_rhs_data",
       "lhs_mapping", "rhs_mapping", "out_mapping"});
-  CheckIdArray(graph->NumBits(),
+  CheckIdArray(graph.NumBits(),
      {lhs_mapping, rhs_mapping, out_mapping},
      {"lhs_mapping", "rhs_mapping", "out_mapping"});
   if (NeedSwitchOrder(op, lhs, rhs)) {

@@ -431,8 +460,9 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
     auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
     CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
+    ImmutableGraphCSRWrapper wrapper(igptr.get());
     BackwardRhsBinaryOpReduce(
-        reducer, op, igptr.get(),
+        reducer, op, wrapper,
        static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
        lhs_mapping, rhs_mapping, out_mapping,
        lhs_data, rhs_data, out_data, grad_out_data,

@@ -441,16 +471,16 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
 void CopyReduce(
     const std::string& reducer,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target target,
     NDArray in_data, NDArray out_data,
     NDArray in_mapping, NDArray out_mapping) {
-  const auto& ctx = graph->Context();
+  const auto& ctx = graph.Context();
   // sanity check
   CheckCtx(ctx,
      {in_data, out_data, in_mapping, out_mapping},
      {"in_data", "out_data", "in_mapping", "out_mapping"});
-  CheckIdArray(graph->NumBits(),
+  CheckIdArray(graph.NumBits(),
      {in_mapping, out_mapping},
      {"in_mapping", "out_mapping"});
   DGL_XPU_SWITCH(ctx.device_type, BinaryReduceImpl,

@@ -472,7 +502,8 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce")
     auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
     CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
-    CopyReduce(reducer, igptr.get(),
+    ImmutableGraphCSRWrapper wrapper(igptr.get());
+    CopyReduce(reducer, wrapper,
        static_cast<binary_op::Target>(target),
        in_data, out_data,
        in_mapping, out_mapping);

@@ -480,7 +511,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce")
 void BackwardCopyReduce(
     const std::string& reducer,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target target,
     NDArray in_mapping,
     NDArray out_mapping,

@@ -488,12 +519,12 @@ void BackwardCopyReduce(
     NDArray out_data,
     NDArray grad_out_data,
     NDArray grad_in_data) {
-  const auto& ctx = graph->Context();
+  const auto& ctx = graph.Context();
   // sanity check
   CheckCtx(ctx,
      {in_data, out_data, grad_out_data, grad_in_data, in_mapping, out_mapping},
      {"in_data", "out_data", "grad_out_data", "grad_in_data", "in_mapping", "out_mapping"});
-  CheckIdArray(graph->NumBits(),
+  CheckIdArray(graph.NumBits(),
      {in_mapping, out_mapping},
      {"in_mapping", "out_mapping"});
   if (!utils::IsNoneArray(out_mapping)) {

@@ -522,8 +553,9 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce")
     auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
     CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
+    ImmutableGraphCSRWrapper wrapper(igptr.get());
     BackwardCopyReduce(
-        reducer, igptr.get(), static_cast<binary_op::Target>(target),
+        reducer, wrapper, static_cast<binary_op::Target>(target),
        in_mapping, out_mapping,
        in_data, out_data, grad_out_data,
        grad_in_data);

src/kernel/binary_reduce.h

@@ -7,12 +7,12 @@
 #define DGL_KERNEL_BINARY_REDUCE_H_
 
 #include <dgl/runtime/ndarray.h>
-#include <dgl/immutable_graph.h>
 #include <vector>
 #include <string>
 
 #include "./binary_reduce_common.h"
+#include "./csr_interface.h"
 
 namespace dgl {
 namespace kernel {

@@ -83,7 +83,7 @@ std::vector<int64_t> InferBinaryFeatureShape(
 void BinaryOpReduce(
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     runtime::NDArray lhs_data, runtime::NDArray rhs_data,
     runtime::NDArray out_data,

@@ -126,7 +126,7 @@ void BinaryOpReduce(
 void BackwardLhsBinaryOpReduce(
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     runtime::NDArray lhs_mapping,
     runtime::NDArray rhs_mapping,

@@ -173,7 +173,7 @@ void BackwardLhsBinaryOpReduce(
 void BackwardRhsBinaryOpReduce(
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     runtime::NDArray lhs_mapping,
     runtime::NDArray rhs_mapping,

@@ -213,7 +213,7 @@ void BackwardRhsBinaryOpReduce(
  */
 void CopyReduce(
     const std::string& reducer,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target target,
     runtime::NDArray in_data, runtime::NDArray out_data,
     runtime::NDArray in_mapping, runtime::NDArray out_mapping);

@@ -236,7 +236,7 @@ void CopyReduce(
  */
 void BackwardCopyReduce(
     const std::string& reducer,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target target,
     runtime::NDArray in_mapping,
     runtime::NDArray out_mapping,

src/kernel/binary_reduce_impl.h

@@ -8,7 +8,6 @@
 #include <minigun/minigun.h>
 #include <dgl/runtime/device_api.h>
-#include <dgl/immutable_graph.h>
 #include <algorithm>
 #include <string>

@@ -18,6 +17,7 @@
 #endif
 #include "./binary_reduce.h"
 #include "./binary_reduce_impl_decl.h"
+#include "./csr_interface.h"
 #include "./utils.h"
 
 namespace dgl {

@@ -58,7 +58,7 @@ template <int XPU>
 void BinaryReduceImpl(
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     runtime::NDArray lhs_data, runtime::NDArray rhs_data,
     runtime::NDArray out_data,

@@ -88,7 +88,7 @@ void BinaryReduceImpl(
     LOG(FATAL) << "reduce mean is not supported.";
   }
   const DLDataType& dtype = out_data->dtype;
-  const auto bits = graph->NumBits();
+  const auto bits = graph.NumBits();
   DGL_DTYPE_SWITCH(dtype, DType, {
     DGL_IDX_TYPE_SWITCH(bits, Idx, {
       REDUCER_SWITCH(reducer, XPU, DType, Reducer, {

@@ -151,7 +151,7 @@ template <int XPU>
 void BackwardBinaryReduceImpl(
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping,
     runtime::NDArray lhs_data, runtime::NDArray rhs_data, runtime::NDArray out_data,

@@ -181,7 +181,7 @@ void BackwardBinaryReduceImpl(
   const DLDataType& dtype = out_data->dtype;
   const bool req_lhs = !utils::IsNoneArray(grad_lhs_data);
   const bool req_rhs = !utils::IsNoneArray(grad_rhs_data);
-  const auto bits = graph->NumBits();
+  const auto bits = graph.NumBits();
   if (reducer == binary_op::kReduceMean) {
     // TODO(minjie): divide
     LOG(FATAL) << "reduce mean is not supported.";

@@ -250,7 +250,7 @@ void BinaryReduceBcastImpl(
     const BcastInfo& info,
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs,
     binary_op::Target rhs,
     runtime::NDArray lhs_data,

@@ -279,7 +279,7 @@ void BinaryReduceBcastImpl(
   const DLDataType& dtype = out_data->dtype;
   const int bcast_ndim = info.out_shape.size();
-  const auto bits = graph->NumBits();
+  const auto bits = graph.NumBits();
   if (reducer == binary_op::kReduceMean) {
     // TODO(minjie): divide
     LOG(FATAL) << "reduce mean is not supported.";

@@ -359,7 +359,7 @@ void BackwardBinaryReduceBcastImpl(
     const BcastInfo& info,
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs_tgt, binary_op::Target rhs_tgt,
     runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping,
     runtime::NDArray lhs, runtime::NDArray rhs, runtime::NDArray out, runtime::NDArray grad_out,

@@ -386,7 +386,7 @@ void BackwardBinaryReduceBcastImpl(
   const int bcast_ndim = info.out_shape.size();
   const bool req_lhs = !utils::IsNoneArray(grad_lhs);
   const bool req_rhs = !utils::IsNoneArray(grad_rhs);
-  const auto bits = graph->NumBits();
+  const auto bits = graph.NumBits();
   if (reducer == binary_op::kReduceMean) {
     // TODO(minjie): divide
     LOG(FATAL) << "reduce mean is not supported.";

src/kernel/binary_reduce_impl_decl.h

@@ -11,6 +11,7 @@
 #include <string>
 
 #include "./binary_reduce_common.h"
+#include "./csr_interface.h"
 
 namespace minigun {
 namespace advance {

@@ -21,9 +22,6 @@ struct RuntimeConfig;
 
 namespace dgl {
 
-// forward declaration
-class ImmutableGraph;
-
 namespace kernel {
 
 // forward declaration

@@ -81,7 +79,7 @@ template <int XPU, typename Idx, typename DType,
           typename BinaryOp, typename Reducer>
 void CallBinaryReduce(
     const minigun::advance::RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     GData<Idx, DType>* gdata);
 
 /*!

@@ -107,7 +105,7 @@ template <int XPU>
 void BinaryReduceImpl(
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     runtime::NDArray lhs_data, runtime::NDArray rhs_data, runtime::NDArray out_data,
     runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping);

@@ -168,7 +166,7 @@ template <int XPU, int Mode, typename Idx, typename DType,
           typename BinaryOp, typename Reducer>
 void CallBackwardBinaryReduce(
     const minigun::advance::RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     BackwardGData<Idx, DType>* gdata);
 
 /*!

@@ -196,7 +194,7 @@ template <int XPU>
 void BackwardBinaryReduceImpl(
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping,
     runtime::NDArray lhs_data, runtime::NDArray rhs_data, runtime::NDArray out_data,

@@ -269,7 +267,7 @@ template <int XPU, int NDim, typename Idx, typename DType,
           typename BinaryOp, typename Reducer>
 void CallBinaryReduceBcast(
     const minigun::advance::RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     BcastGData<NDim, Idx, DType>* gdata);
 
 /*!

@@ -296,7 +294,7 @@ void BinaryReduceBcastImpl(
     const BcastInfo& info,
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     runtime::NDArray lhs_data, runtime::NDArray rhs_data,
     runtime::NDArray out_data,

@@ -370,7 +368,7 @@ template <int XPU, int Mode, int NDim, typename Idx, typename DType,
           typename BinaryOp, typename Reducer>
 void CallBackwardBinaryReduceBcast(
     const minigun::advance::RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     BackwardBcastGData<NDim, Idx, DType>* gdata);
 
 /*!

@@ -399,7 +397,7 @@ void BackwardBinaryReduceBcastImpl(
     const BcastInfo& info,
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping,
     runtime::NDArray lhs_data, runtime::NDArray rhs_data, runtime::NDArray out_data,

src/kernel/cpu/backward_binary_reduce_impl.h

@@ -7,11 +7,11 @@
 #define DGL_KERNEL_CPU_BACKWARD_BINARY_REDUCE_IMPL_H_
 
 #include <minigun/minigun.h>
-#include <dgl/immutable_graph.h>
 
 #include "../binary_reduce_impl_decl.h"
 #include "../utils.h"
 #include "./functor.h"
+#include "../csr_interface.h"
 
 namespace dgl {
 namespace kernel {

@@ -170,14 +170,14 @@ template <int XPU, int Mode, typename Idx, typename DType,
           typename BinaryOp, typename Reducer>
 void CallBackwardBinaryReduce(
     const minigun::advance::RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     BackwardGData<Idx, DType>* gdata) {
   // For backward computation, we use reverse csr and switch dst and src.
   // This benefits the most common src_op_edge or copy_src case, because the
   // gradients of src are now aggregated into destination buffer to reduce
   // competition of atomic add.
-  auto incsr = graph->GetInCSR();
-  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr->indptr(), incsr->indices());
+  auto incsr = graph.GetInCSRMatrix();
+  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr.indptr, incsr.indices);
   typedef cpu::BackwardFunctorsTempl<Idx, DType,
           typename SwitchSrcDst<LeftSelector>::Type,
           typename SwitchSrcDst<RightSelector>::Type,

@@ -188,15 +188,15 @@ void CallBackwardBinaryReduce(
   // data is correctly read/written.
   if (LeftSelector::target == binary_op::kEdge
       && gdata->lhs_mapping == nullptr) {
-    gdata->lhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
+    gdata->lhs_mapping = static_cast<Idx*>(incsr.data->data);
   }
   if (RightSelector::target == binary_op::kEdge
       && gdata->rhs_mapping == nullptr) {
-    gdata->rhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
+    gdata->rhs_mapping = static_cast<Idx*>(incsr.data->data);
   }
   if (OutSelector<Reducer>::Type::target == binary_op::kEdge
       && gdata->out_mapping == nullptr) {
-    gdata->out_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
+    gdata->out_mapping = static_cast<Idx*>(incsr.data->data);
   }
   // TODO(minjie): allocator
   minigun::advance::Advance<XPU, Idx, cpu::AdvanceConfig, BackwardGData<Idx, DType>, UDF>(

@@ -211,7 +211,7 @@ void CallBackwardBinaryReduce(
         lhs_tgt, rhs_tgt,                          \
         op<dtype>, REDUCER<XPU, dtype>>(           \
       const minigun::advance::RuntimeConfig& rtcfg, \
-      const ImmutableGraph* graph,                 \
+      const CSRWrapper& graph,                     \
       BackwardGData<IDX, dtype>* gdata);
 
 // Template implementation of BackwardBinaryReduce with broadcasting operator.

@@ -220,14 +220,14 @@ template <int XPU, int Mode, int NDim, typename Idx, typename DType,
           typename BinaryOp, typename Reducer>
 void CallBackwardBinaryReduceBcast(
     const minigun::advance::RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     BackwardBcastGData<NDim, Idx, DType>* gdata) {
   // For backward computation, we use reverse csr and switch dst and src.
   // This benefits the most common src_op_edge or copy_src case, because the
   // gradients of src are now aggregated into destination buffer to reduce
   // competition of atomic add.
-  auto incsr = graph->GetInCSR();
-  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr->indptr(), incsr->indices());
+  auto incsr = graph.GetInCSRMatrix();
+  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr.indptr, incsr.indices);
   typedef cpu::BackwardFunctorsTempl<Idx, DType,
           typename SwitchSrcDst<LeftSelector>::Type,
           typename SwitchSrcDst<RightSelector>::Type,

@@ -238,15 +238,15 @@ void CallBackwardBinaryReduceBcast(
   // data is correctly read/written.
   if (LeftSelector::target == binary_op::kEdge
       && gdata->lhs_mapping == nullptr) {
-    gdata->lhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
+    gdata->lhs_mapping = static_cast<Idx*>(incsr.data->data);
   }
   if (RightSelector::target == binary_op::kEdge
       && gdata->rhs_mapping == nullptr) {
-    gdata->rhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
+    gdata->rhs_mapping = static_cast<Idx*>(incsr.data->data);
   }
   if (OutSelector<Reducer>::Type::target == binary_op::kEdge
       && gdata->out_mapping == nullptr) {
-    gdata->out_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
+    gdata->out_mapping = static_cast<Idx*>(incsr.data->data);
   }
   // TODO(minjie): allocator
   minigun::advance::Advance<XPU, Idx, cpu::AdvanceConfig,

@@ -262,7 +262,7 @@ void CallBackwardBinaryReduceBcast(
         lhs_tgt, rhs_tgt,                          \
         op<dtype>, REDUCER<XPU, dtype>>(           \
       const minigun::advance::RuntimeConfig& rtcfg, \
-      const ImmutableGraph* graph,                 \
+      const CSRWrapper& graph,                     \
       BackwardBcastGData<ndim, IDX, dtype>* gdata);
 
 }  // namespace kernel

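The comment repeated in `CallBackwardBinaryReduce` and `CallBackwardBinaryReduceBcast` above carries the key design point: the forward kernels traverse the out-CSR (rows are source nodes), while the backward kernels traverse the in-CSR (the transposed adjacency, rows are destination nodes) with the src/dst selectors switched, so that, per the comment, gradient contributions are aggregated row-wise and atomic-add contention drops. A small DGL-independent illustration of the two layouts, and of why the in-CSR's row-permuted edge ids must be substituted as the implicit edge mapping (the `incsr.data` lines above):

```cpp
#include <cstdio>
#include <vector>

// Toy graph with edges (and ids): e0: 0->2, e1: 1->0, e2: 1->2, e3: 2->0.
int main() {
  // Out-CSR: row u lists u's outgoing edges; the "data" field holds edge ids.
  std::vector<int> out_indptr  = {0, 1, 3, 4};
  std::vector<int> out_indices = {2, 0, 2, 0};  // destinations
  std::vector<int> out_eids    = {0, 1, 2, 3};
  // In-CSR: row v lists v's incoming edges. The edge ids come out permuted
  // ({1, 3, 0, 2} here), which is why the kernels substitute the CSR's data
  // field for a missing edge mapping before launching.
  std::vector<int> in_indptr  = {0, 2, 2, 4};
  std::vector<int> in_indices = {1, 2, 0, 1};   // sources
  std::vector<int> in_eids    = {1, 3, 0, 2};
  for (int u = 0; u < 3; ++u)
    for (int i = out_indptr[u]; i < out_indptr[u + 1]; ++i)
      std::printf("forward  visits e%d (%d -> %d) scanning row %d\n",
                  out_eids[i], u, out_indices[i], u);
  for (int v = 0; v < 3; ++v)
    for (int i = in_indptr[v]; i < in_indptr[v + 1]; ++i)
      std::printf("backward visits e%d (%d -> %d) scanning row %d\n",
                  in_eids[i], in_indices[i], v, v);
  return 0;
}
```
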
src/kernel/cpu/binary_reduce_impl.cc

@@ -4,6 +4,7 @@
 * \brief Binary reduce implementation on CPU.
 */
 #include "../binary_reduce_impl.h"
+#include "../csr_interface.h"
 
 using dgl::runtime::NDArray;

@@ -13,7 +14,7 @@ namespace kernel {
 template void BinaryReduceImpl<kDLCPU>(
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     runtime::NDArray lhs_data, runtime::NDArray rhs_data,
     runtime::NDArray out_data,

@@ -24,7 +25,7 @@ template void BinaryReduceBcastImpl<kDLCPU>(
     const BcastInfo& info,
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     runtime::NDArray lhs_data, runtime::NDArray rhs_data,
     runtime::NDArray out_data,

@@ -34,7 +35,7 @@ template void BinaryReduceBcastImpl<kDLCPU>(
 template void BackwardBinaryReduceImpl<kDLCPU>(
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     NDArray lhs_mapping, NDArray rhs_mapping, NDArray out_mapping,
     NDArray lhs_data, NDArray rhs_data, NDArray out_data,

@@ -45,7 +46,7 @@ template void BackwardBinaryReduceBcastImpl<kDLCPU>(
     const BcastInfo& info,
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs_tgt, binary_op::Target rhs_tgt,
     runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping,
     runtime::NDArray lhs, runtime::NDArray rhs, runtime::NDArray out, runtime::NDArray grad_out,

src/kernel/cpu/binary_reduce_impl.h

@@ -7,13 +7,13 @@
 #define DGL_KERNEL_CPU_BINARY_REDUCE_IMPL_H_
 
 #include <minigun/minigun.h>
-#include <dgl/immutable_graph.h>
 #include <algorithm>
 
 #include "../binary_reduce_impl_decl.h"
 #include "../utils.h"
 #include "./functor.h"
+#include "../csr_interface.h"
 
 namespace dgl {
 namespace kernel {

@@ -148,27 +148,27 @@ template <int XPU, typename Idx, typename DType,
           typename LeftSelector, typename RightSelector,
           typename BinaryOp, typename Reducer>
 void CallBinaryReduce(const minigun::advance::RuntimeConfig& rtcfg,
-                      const ImmutableGraph* graph,
+                      const CSRWrapper& graph,
                       GData<Idx, DType>* gdata) {
   typedef cpu::FunctorsTempl<Idx, DType, LeftSelector,
                         RightSelector, BinaryOp, Reducer>
           Functors;
   typedef cpu::BinaryReduce<Idx, DType, Functors> UDF;
   // csr
-  auto outcsr = graph->GetOutCSR();
-  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr->indptr(), outcsr->indices());
+  auto outcsr = graph.GetOutCSRMatrix();
+  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr.indptr, outcsr.indices);
   // If the user-given mapping is none and the target is edge data, we need to
   // replace the mapping by the edge ids in the csr graph so that the edge
   // data is correctly read/written.
   if (LeftSelector::target == binary_op::kEdge && gdata->lhs_mapping == nullptr) {
-    gdata->lhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
+    gdata->lhs_mapping = static_cast<Idx*>(outcsr.data->data);
   }
   if (RightSelector::target == binary_op::kEdge && gdata->rhs_mapping == nullptr) {
-    gdata->rhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
+    gdata->rhs_mapping = static_cast<Idx*>(outcsr.data->data);
   }
   if (OutSelector<Reducer>::Type::target == binary_op::kEdge
       && gdata->out_mapping == nullptr) {
-    gdata->out_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
+    gdata->out_mapping = static_cast<Idx*>(outcsr.data->data);
   }
   // TODO(minjie): allocator
   minigun::advance::Advance<XPU, Idx, cpu::AdvanceConfig, GData<Idx, DType>, UDF>(

@@ -181,27 +181,27 @@ template <int XPU, int NDim, typename Idx, typename DType,
           typename BinaryOp, typename Reducer>
 void CallBinaryReduceBcast(
     const minigun::advance::RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     BcastGData<NDim, Idx, DType>* gdata) {
   typedef cpu::FunctorsTempl<Idx, DType, LeftSelector,
                         RightSelector, BinaryOp, Reducer>
           Functors;
   typedef cpu::BinaryReduceBcast<NDim, Idx, DType, Functors> UDF;
   // csr
-  auto outcsr = graph->GetOutCSR();
-  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr->indptr(), outcsr->indices());
+  auto outcsr = graph.GetOutCSRMatrix();
+  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr.indptr, outcsr.indices);
   // If the user-given mapping is none and the target is edge data, we need to
   // replace the mapping by the edge ids in the csr graph so that the edge
   // data is correctly read/written.
   if (LeftSelector::target == binary_op::kEdge && gdata->lhs_mapping == nullptr) {
-    gdata->lhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
+    gdata->lhs_mapping = static_cast<Idx*>(outcsr.data->data);
   }
   if (RightSelector::target == binary_op::kEdge && gdata->rhs_mapping == nullptr) {
-    gdata->rhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
+    gdata->rhs_mapping = static_cast<Idx*>(outcsr.data->data);
   }
   if (OutSelector<Reducer>::Type::target == binary_op::kEdge
       && gdata->out_mapping == nullptr) {
-    gdata->out_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
+    gdata->out_mapping = static_cast<Idx*>(outcsr.data->data);
   }
   // TODO(minjie): allocator
   minigun::advance::Advance<XPU, Idx, cpu::AdvanceConfig,

@@ -215,7 +215,7 @@ void CallBinaryReduceBcast(
 template void CallBinaryReduce<XPU, IDX,           \
     dtype, lhs_tgt, rhs_tgt, op<dtype>, REDUCER<XPU, dtype>>( \
       const minigun::advance::RuntimeConfig& rtcfg, \
-      const ImmutableGraph* graph,                 \
+      const CSRWrapper& graph,                     \
       GData<IDX, dtype>* gdata);
 
 #define GEN_BCAST_DEFINE(ndim, dtype, lhs_tgt, rhs_tgt, op) \

@@ -223,7 +223,7 @@ void CallBinaryReduceBcast(
         lhs_tgt, rhs_tgt,                          \
         op<dtype>, REDUCER<XPU, dtype>>(           \
       const minigun::advance::RuntimeConfig& rtcfg, \
-      const ImmutableGraph* graph,                 \
+      const CSRWrapper& graph,                     \
       BcastGData<ndim, IDX, dtype>* gdata);
 
 #define EVAL(F, ...) MSVC_EXPAND(F(__VA_ARGS__))

src/kernel/csr_interface.h (new file)

/*!
 *  Copyright (c) 2019 by Contributors
 * \file kernel/csr_interface.h
 * \brief Kernel common utilities
 */
#ifndef DGL_KERNEL_CSR_INTERFACE_H_
#define DGL_KERNEL_CSR_INTERFACE_H_

#include <dgl/array.h>
#include <dgl/runtime/c_runtime_api.h>

namespace dgl {
namespace kernel {

/*!
 * \brief Wrapper class that unifies ImmutableGraph and Bipartite, which do
 * not share a base class.
 *
 * \note This is an ugly temporary solution, and shall be removed after
 * refactoring ImmutableGraph and Bipartite to use the same data structure.
 */
class CSRWrapper {
 public:
  virtual aten::CSRMatrix GetInCSRMatrix() const = 0;
  virtual aten::CSRMatrix GetOutCSRMatrix() const = 0;
  virtual DGLContext Context() const = 0;
  virtual int NumBits() const = 0;
};

}  // namespace kernel
}  // namespace dgl

#endif  // DGL_KERNEL_CSR_INTERFACE_H_
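
A sketch of the consumer side, not part of the commit: kernel code can stay agnostic of the concrete graph class by touching only this interface plus the `aten::CSRMatrix` fields (`indptr`, `indices`, `data`) used elsewhere in the diff. `NumEdges` below is a hypothetical helper; it assumes a CPU-resident `indptr` for the raw pointer read and dispatches on `NumBits()` for the id width:

```cpp
#include <cstdint>
#include <dgl/array.h>
#include "./csr_interface.h"

namespace dgl {
namespace kernel {

// Hypothetical helper, for illustration only: the edge count can be read off
// the out-CSR's indptr without ever naming the underlying graph class.
inline int64_t NumEdges(const CSRWrapper& graph) {
  aten::CSRMatrix csr = graph.GetOutCSRMatrix();
  // indptr has num_rows + 1 entries; the last one is the number of edges.
  const int64_t num_rows = csr.indptr->shape[0] - 1;
  // The indptr element width follows graph.NumBits(), so dispatch before casting.
  if (graph.NumBits() == 64)
    return static_cast<const int64_t*>(csr.indptr->data)[num_rows];
  return static_cast<const int32_t*>(csr.indptr->data)[num_rows];
}

}  // namespace kernel
}  // namespace dgl
```
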
src/kernel/cuda/backward_binary_reduce_impl.cuh

@@ -7,11 +7,11 @@
 #define DGL_KERNEL_CUDA_BACKWARD_BINARY_REDUCE_IMPL_CUH_
 
 #include <minigun/minigun.h>
-#include <dgl/immutable_graph.h>
 
 #include "../binary_reduce_impl_decl.h"
 #include "../utils.h"
 #include "./functor.cuh"
+#include "../csr_interface.h"
 
 namespace dgl {
 namespace kernel {

@@ -171,14 +171,14 @@ template <int XPU, int Mode, typename Idx, typename DType,
           typename BinaryOp, typename Reducer>
 void CallBackwardBinaryReduce(
     const minigun::advance::RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     BackwardGData<Idx, DType>* gdata) {
   // For backward computation, we use reverse csr and switch dst and src.
   // This benefits the most common src_op_edge or copy_src case, because the
   // gradients of src are now aggregated into destination buffer to reduce
   // competition of atomic add.
-  auto incsr = graph->GetInCSR();
-  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr->indptr(), incsr->indices());
+  auto incsr = graph.GetInCSRMatrix();
+  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr.indptr, incsr.indices);
   typedef cuda::BackwardFunctorsTempl<Idx, DType,
           typename SwitchSrcDst<LeftSelector>::Type,
           typename SwitchSrcDst<RightSelector>::Type,

@@ -189,15 +189,15 @@ void CallBackwardBinaryReduce(
   // data is correctly read/written.
   if (LeftSelector::target == binary_op::kEdge
       && gdata->lhs_mapping == nullptr) {
-    gdata->lhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
+    gdata->lhs_mapping = static_cast<Idx*>(incsr.data->data);
   }
   if (RightSelector::target == binary_op::kEdge
       && gdata->rhs_mapping == nullptr) {
-    gdata->rhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
+    gdata->rhs_mapping = static_cast<Idx*>(incsr.data->data);
   }
   if (OutSelector<Reducer>::Type::target == binary_op::kEdge
       && gdata->out_mapping == nullptr) {
-    gdata->out_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
+    gdata->out_mapping = static_cast<Idx*>(incsr.data->data);
   }
   // TODO(minjie): allocator
   minigun::advance::Advance<XPU, Idx, cuda::AdvanceConfig, BackwardGData<Idx, DType>, UDF>(

@@ -212,7 +212,7 @@ void CallBackwardBinaryReduce(
         lhs_tgt, rhs_tgt,                          \
         op<dtype>, REDUCER<XPU, dtype>>(           \
       const minigun::advance::RuntimeConfig& rtcfg, \
-      const ImmutableGraph* graph,                 \
+      const CSRWrapper& graph,                     \
       BackwardGData<IDX, dtype>* gdata);
 
 // Template implementation of BackwardBinaryReduce with broadcasting operator.

@@ -221,14 +221,14 @@ template <int XPU, int Mode, int NDim, typename Idx, typename DType,
           typename BinaryOp, typename Reducer>
 void CallBackwardBinaryReduceBcast(
     const minigun::advance::RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     BackwardBcastGData<NDim, Idx, DType>* gdata) {
   // For backward computation, we use reverse csr and switch dst and src.
   // This benefits the most common src_op_edge or copy_src case, because the
   // gradients of src are now aggregated into destination buffer to reduce
   // competition of atomic add.
-  auto incsr = graph->GetInCSR();
-  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr->indptr(), incsr->indices());
+  auto incsr = graph.GetInCSRMatrix();
+  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr.indptr, incsr.indices);
   typedef cuda::BackwardFunctorsTempl<Idx, DType,
           typename SwitchSrcDst<LeftSelector>::Type,
           typename SwitchSrcDst<RightSelector>::Type,

@@ -239,15 +239,15 @@ void CallBackwardBinaryReduceBcast(
   // data is correctly read/written.
   if (LeftSelector::target == binary_op::kEdge
       && gdata->lhs_mapping == nullptr) {
-    gdata->lhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
+    gdata->lhs_mapping = static_cast<Idx*>(incsr.data->data);
   }
   if (RightSelector::target == binary_op::kEdge
       && gdata->rhs_mapping == nullptr) {
-    gdata->rhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
+    gdata->rhs_mapping = static_cast<Idx*>(incsr.data->data);
   }
   if (OutSelector<Reducer>::Type::target == binary_op::kEdge
       && gdata->out_mapping == nullptr) {
-    gdata->out_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
+    gdata->out_mapping = static_cast<Idx*>(incsr.data->data);
   }
   // TODO(minjie): allocator
   minigun::advance::Advance<XPU, Idx, cuda::AdvanceConfig,

@@ -263,7 +263,7 @@ void CallBackwardBinaryReduceBcast(
         lhs_tgt, rhs_tgt,                          \
         op<dtype>, REDUCER<XPU, dtype>>(           \
       const minigun::advance::RuntimeConfig& rtcfg, \
-      const ImmutableGraph* graph,                 \
+      const CSRWrapper& graph,                     \
       BackwardBcastGData<ndim, IDX, dtype>* gdata);
 
 }  // namespace kernel

src/kernel/cuda/binary_reduce_impl.cu

@@ -4,6 +4,7 @@
 * \brief Binary reduce implementation on cuda.
 */
 #include "../binary_reduce_impl.h"
+#include "../csr_interface.h"
 
 using dgl::runtime::NDArray;

@@ -13,7 +14,7 @@ namespace kernel {
 template void BinaryReduceImpl<kDLGPU>(
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     runtime::NDArray lhs_data, runtime::NDArray rhs_data,
     runtime::NDArray out_data,

@@ -24,7 +25,7 @@ template void BinaryReduceBcastImpl<kDLGPU>(
     const BcastInfo& info,
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     runtime::NDArray lhs_data, runtime::NDArray rhs_data,
     runtime::NDArray out_data,

@@ -34,7 +35,7 @@ template void BinaryReduceBcastImpl<kDLGPU>(
 template void BackwardBinaryReduceImpl<kDLGPU>(
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs, binary_op::Target rhs,
     NDArray lhs_mapping, NDArray rhs_mapping, NDArray out_mapping,
     NDArray lhs_data, NDArray rhs_data, NDArray out_data,

@@ -45,7 +46,7 @@ template void BackwardBinaryReduceBcastImpl<kDLGPU>(
     const BcastInfo& info,
     const std::string& reducer,
     const std::string& op,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
     binary_op::Target lhs_tgt, binary_op::Target rhs_tgt,
     runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping,
     runtime::NDArray lhs, runtime::NDArray rhs, runtime::NDArray out, runtime::NDArray grad_out,

@@ -7,11 +7,11 @@
#define DGL_KERNEL_CUDA_BINARY_REDUCE_IMPL_CUH_

#include <minigun/minigun.h>
-#include <dgl/immutable_graph.h>

#include "../binary_reduce_impl_decl.h"
#include "../utils.h"
#include "./functor.cuh"
+#include "../csr_interface.h"

namespace dgl {
namespace kernel {
@@ -151,27 +151,27 @@ template <int XPU, typename Idx, typename DType,
          typename LeftSelector, typename RightSelector,
          typename BinaryOp, typename Reducer>
void CallBinaryReduce(const minigun::advance::RuntimeConfig& rtcfg,
-                      const ImmutableGraph* graph,
+                      const CSRWrapper& graph,
                      GData<Idx, DType>* gdata) {
  typedef cuda::FunctorsTempl<Idx, DType, LeftSelector,
                              RightSelector, BinaryOp, Reducer>
          Functors;
  typedef cuda::BinaryReduce<Idx, DType, Functors> UDF;
  // csr
-  auto outcsr = graph->GetOutCSR();
-  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr->indptr(), outcsr->indices());
+  auto outcsr = graph.GetOutCSRMatrix();
+  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr.indptr, outcsr.indices);
  // If the user-given mapping is none and the target is edge data, we need to
  // replace the mapping by the edge ids in the csr graph so that the edge
  // data is correctly read/written.
  if (LeftSelector::target == binary_op::kEdge && gdata->lhs_mapping == nullptr) {
-    gdata->lhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
+    gdata->lhs_mapping = static_cast<Idx*>(outcsr.data->data);
  }
  if (RightSelector::target == binary_op::kEdge && gdata->rhs_mapping == nullptr) {
-    gdata->rhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
+    gdata->rhs_mapping = static_cast<Idx*>(outcsr.data->data);
  }
  if (OutSelector<Reducer>::Type::target == binary_op::kEdge
      && gdata->out_mapping == nullptr) {
-    gdata->out_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
+    gdata->out_mapping = static_cast<Idx*>(outcsr.data->data);
  }
  // TODO(minjie): allocator
  minigun::advance::Advance<XPU, Idx, cuda::AdvanceConfig, GData<Idx, DType>, UDF>(
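The three nullptr checks above exist because a CSR stores its nonzeros in row-major traversal order, which generally differs from the order in which edges (and hence edge features) were created; the matrix's data array records the original edge id of each nonzero. Substituting it as the default mapping makes edge reads and writes land on the correct feature rows. A minimal illustration, with hypothetical names:

// Hypothetical helper, for illustration only: how an edge mapping
// translates a nonzero position `nz` in the CSR into the row of the
// edge feature tensor.
template <typename Idx, typename DType>
DType ReadEdgeFeature(const DType* edge_feat, const Idx* mapping,
                      int64_t nz, int64_t feat_idx, int64_t feat_len) {
  const Idx eid = mapping ? mapping[nz] : static_cast<Idx>(nz);
  return edge_feat[eid * feat_len + feat_idx];
}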
@@ -184,27 +184,27 @@ template <int XPU, int NDim, typename Idx, typename DType,
          typename BinaryOp, typename Reducer>
void CallBinaryReduceBcast(
    const minigun::advance::RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
    BcastGData<NDim, Idx, DType>* gdata) {
  typedef cuda::FunctorsTempl<Idx, DType, LeftSelector,
                              RightSelector, BinaryOp, Reducer>
          Functors;
  typedef cuda::BinaryReduceBcast<NDim, Idx, DType, Functors> UDF;
  // csr
-  auto outcsr = graph->GetOutCSR();
-  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr->indptr(), outcsr->indices());
+  auto outcsr = graph.GetOutCSRMatrix();
+  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr.indptr, outcsr.indices);
  // If the user-given mapping is none and the target is edge data, we need to
  // replace the mapping by the edge ids in the csr graph so that the edge
  // data is correctly read/written.
  if (LeftSelector::target == binary_op::kEdge && gdata->lhs_mapping == nullptr) {
-    gdata->lhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
+    gdata->lhs_mapping = static_cast<Idx*>(outcsr.data->data);
  }
  if (RightSelector::target == binary_op::kEdge && gdata->rhs_mapping == nullptr) {
-    gdata->rhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
+    gdata->rhs_mapping = static_cast<Idx*>(outcsr.data->data);
  }
  if (OutSelector<Reducer>::Type::target == binary_op::kEdge
      && gdata->out_mapping == nullptr) {
-    gdata->out_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
+    gdata->out_mapping = static_cast<Idx*>(outcsr.data->data);
  }
  // TODO(minjie): allocator
  minigun::advance::Advance<XPU, Idx, cuda::AdvanceConfig,
@@ -218,7 +218,7 @@ void CallBinaryReduceBcast(
  template void CallBinaryReduce<XPU, IDX, \
      dtype, lhs_tgt, rhs_tgt, op<dtype>, REDUCER<XPU, dtype>>( \
      const minigun::advance::RuntimeConfig& rtcfg, \
-      const ImmutableGraph* graph, \
+      const CSRWrapper& graph, \
      GData<IDX, dtype>* gdata);

#define GEN_BCAST_DEFINE(ndim, dtype, lhs_tgt, rhs_tgt, op) \
@@ -226,7 +226,7 @@ void CallBinaryReduceBcast(
      lhs_tgt, rhs_tgt, \
      op<dtype>, REDUCER<XPU, dtype>>( \
      const minigun::advance::RuntimeConfig& rtcfg, \
-      const ImmutableGraph* graph, \
+      const CSRWrapper& graph, \
      BcastGData<ndim, IDX, dtype>* gdata);

#define EVAL(F, ...) MSVC_EXPAND(F(__VA_ARGS__))
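Both macros stamp out explicit template instantiations; only their signatures change here. For illustration only, one plausible expansion of GEN_BCAST_DEFINE(2, float, SelectSrc, SelectDst, BinaryAdd) after this patch, assuming XPU, IDX, and REDUCER are bound elsewhere in the file to kDLGPU, int32_t, and ReduceSum (the concrete argument values are hypothetical):

// Assumed expansion -- argument bindings are illustrative.
template void CallBinaryReduceBcast<kDLGPU, 2, int32_t, float,
    SelectSrc, SelectDst,
    BinaryAdd<float>, ReduceSum<kDLGPU, float>>(
    const minigun::advance::RuntimeConfig& rtcfg,
    const CSRWrapper& graph,
    BcastGData<2, int32_t, float>* gdata);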
...
@@ -9,6 +9,7 @@
#include "./binary_reduce_impl.cuh"
#include "./backward_binary_reduce_impl.cuh"
#include "../utils.h"
+#include "../csr_interface.h"

using minigun::advance::RuntimeConfig;
using Csr = minigun::Csr<int32_t>;
@@ -153,7 +154,7 @@ void CusparseCsrmm2(
template <typename DType>
void FallbackCallBinaryReduce(
    const RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
    GData<int32_t, DType>* gdata) {
  constexpr int XPU = kDLGPU;
  typedef int32_t Idx;
@@ -166,20 +167,20 @@ void FallbackCallBinaryReduce(
          Functors;
  typedef cuda::BinaryReduce<Idx, DType, Functors> UDF;
  // csr
-  auto outcsr = graph->GetOutCSR();
-  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr->indptr(), outcsr->indices());
+  auto outcsr = graph.GetOutCSRMatrix();
+  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr.indptr, outcsr.indices);
  // If the user-given mapping is none and the target is edge data, we need to
  // replace the mapping by the edge ids in the csr graph so that the edge
  // data is correctly read/written.
  if (LeftSelector::target == binary_op::kEdge && gdata->lhs_mapping == nullptr) {
-    gdata->lhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
+    gdata->lhs_mapping = static_cast<Idx*>(outcsr.data->data);
  }
  if (RightSelector::target == binary_op::kEdge && gdata->rhs_mapping == nullptr) {
-    gdata->rhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
+    gdata->rhs_mapping = static_cast<Idx*>(outcsr.data->data);
  }
  if (OutSelector<Reducer>::Type::target == binary_op::kEdge
      && gdata->out_mapping == nullptr) {
-    gdata->out_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
+    gdata->out_mapping = static_cast<Idx*>(outcsr.data->data);
  }
  // TODO(minjie): allocator
  minigun::advance::Advance<XPU, Idx, cuda::AdvanceConfig, GData<Idx, DType>, UDF>(
@@ -189,7 +190,7 @@ void FallbackCallBinaryReduce(
template <typename DType>
void FallbackCallBackwardBinaryReduce(
    const RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
    BackwardGData<int32_t, DType>* gdata) {
  constexpr int XPU = kDLGPU;
  constexpr int Mode = binary_op::kGradLhs;
@@ -202,8 +203,8 @@ void FallbackCallBackwardBinaryReduce(
  // This benefits the most common src_op_edge or copy_src case, because the
  // gradients of src are now aggregated into destination buffer to reduce
  // competition of atomic add.
-  auto incsr = graph->GetInCSR();
-  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr->indptr(), incsr->indices());
+  auto incsr = graph.GetInCSRMatrix();
+  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr.indptr, incsr.indices);
  typedef cuda::BackwardFunctorsTempl<Idx, DType,
          typename SwitchSrcDst<LeftSelector>::Type,
          typename SwitchSrcDst<RightSelector>::Type,
@@ -214,15 +215,15 @@
  // data is correctly read/written.
  if (LeftSelector::target == binary_op::kEdge
      && gdata->lhs_mapping == nullptr) {
-    gdata->lhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
+    gdata->lhs_mapping = static_cast<Idx*>(incsr.data->data);
  }
  if (RightSelector::target == binary_op::kEdge
      && gdata->rhs_mapping == nullptr) {
-    gdata->rhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
+    gdata->rhs_mapping = static_cast<Idx*>(incsr.data->data);
  }
  if (OutSelector<Reducer>::Type::target == binary_op::kEdge
      && gdata->out_mapping == nullptr) {
-    gdata->out_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
+    gdata->out_mapping = static_cast<Idx*>(incsr.data->data);
  }
  // TODO(minjie): allocator
  minigun::advance::Advance<XPU, Idx, cuda::AdvanceConfig, BackwardGData<Idx, DType>, UDF>(
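The comment about atomic-add competition reflects a scatter-versus-gather choice: traversing the CSR whose rows correspond to the tensor being written lets each row be accumulated privately, whereas the straightforward one-thread-per-edge scatter would atomicAdd into the same gradient row from every edge of a high-degree node. A serial CPU sketch of the gather form this orientation enables (names are illustrative, not the actual minigun traversal):

// Sketch only: gradient gather along rows of the appropriately
// oriented CSR. A parallel version needs no atomics on `grad`,
// because only the thread owning row u writes grad[u].
void GatherGrad(const int32_t* indptr, const int32_t* indices,
                const float* grad_out, float* grad,
                int64_t num_rows, int64_t feat_len) {
  for (int64_t u = 0; u < num_rows; ++u) {        // parallelize over rows
    for (int32_t e = indptr[u]; e < indptr[u + 1]; ++e) {
      const int32_t v = indices[e];               // neighbor on the edge
      for (int64_t k = 0; k < feat_len; ++k)
        grad[u * feat_len + k] += grad_out[v * feat_len + k];
    }
  }
}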
@@ -235,14 +236,14 @@ template <>
void CallBinaryReduce<kDLGPU, int32_t, float, SelectSrc, SelectNone,
                      BinaryUseLhs<float>, ReduceSum<kDLGPU, float>>(
    const RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
    GData<int32_t, float>* gdata) {
  if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
    cuda::FallbackCallBinaryReduce<float>(rtcfg, graph, gdata);
  } else {
    // cusparse use rev csr for csrmm
-    auto incsr = graph->GetInCSR();
-    Csr csr = utils::CreateCsr<int32_t>(incsr->indptr(), incsr->indices());
+    auto incsr = graph.GetInCSRMatrix();
+    Csr csr = utils::CreateCsr<int32_t>(incsr.indptr, incsr.indices);
    cuda::CusparseCsrmm2(rtcfg, csr, gdata->lhs_data, gdata->out_data,
                         gdata->out_size, gdata->x_length);
  }
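For this mapping-free copy-src/sum specialization, the kernel is just a sparse-dense matrix product: out = A_in * lhs, where row v of the in-CSR lists the sources u of v's incoming edges, which is why the comment calls it the "rev csr". A serial CPU sketch of the product the CusparseCsrmm2 call stands in for (names are illustrative):

// Reference sketch: out[v, :] = sum over incoming edges (u -> v)
// of lhs[u, :], with (indptr, indices) taken from the in-CSR,
// whose rows are destination nodes.
void RefCopySrcSum(const int32_t* indptr, const int32_t* indices,
                   const float* lhs, float* out,
                   int64_t num_dst, int64_t feat_len) {
  for (int64_t v = 0; v < num_dst; ++v) {
    for (int64_t k = 0; k < feat_len; ++k) out[v * feat_len + k] = 0.f;
    for (int32_t e = indptr[v]; e < indptr[v + 1]; ++e) {
      const int32_t u = indices[e];
      for (int64_t k = 0; k < feat_len; ++k)
        out[v * feat_len + k] += lhs[u * feat_len + k];
    }
  }
}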
@@ -252,14 +253,14 @@ template <>
void CallBinaryReduce<kDLGPU, int32_t, double, SelectSrc, SelectNone,
                      BinaryUseLhs<double>, ReduceSum<kDLGPU, double>>(
    const RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
    GData<int32_t, double>* gdata) {
  if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
    cuda::FallbackCallBinaryReduce<double>(rtcfg, graph, gdata);
  } else {
    // cusparse use rev csr for csrmm
-    auto incsr = graph->GetInCSR();
-    Csr csr = utils::CreateCsr<int32_t>(incsr->indptr(), incsr->indices());
+    auto incsr = graph.GetInCSRMatrix();
+    Csr csr = utils::CreateCsr<int32_t>(incsr.indptr, incsr.indices);
    cuda::CusparseCsrmm2(rtcfg, csr, gdata->lhs_data, gdata->out_data,
                         gdata->out_size, gdata->x_length);
  }
@@ -272,13 +273,13 @@ void CallBackwardBinaryReduce<kDLGPU, binary_op::kGradLhs, int32_t, float,
                              SelectSrc, SelectNone,
                              BinaryUseLhs<float>, ReduceSum<kDLGPU, float>>(
    const RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
    BackwardGData<int32_t, float>* gdata) {
  if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
    cuda::FallbackCallBackwardBinaryReduce<float>(rtcfg, graph, gdata);
  } else {
-    auto outcsr = graph->GetOutCSR();
-    Csr csr = utils::CreateCsr<int32_t>(outcsr->indptr(), outcsr->indices());
+    auto outcsr = graph.GetOutCSRMatrix();
+    Csr csr = utils::CreateCsr<int32_t>(outcsr.indptr, outcsr.indices);
    cuda::CusparseCsrmm2(rtcfg, csr, gdata->grad_out_data, gdata->grad_lhs_data,
                         gdata->out_size, gdata->x_length);
  }
@@ -289,13 +290,13 @@ void CallBackwardBinaryReduce<kDLGPU, binary_op::kGradLhs, int32_t, double,
                              SelectSrc, SelectNone,
                              BinaryUseLhs<double>, ReduceSum<kDLGPU, double>>(
    const RuntimeConfig& rtcfg,
-    const ImmutableGraph* graph,
+    const CSRWrapper& graph,
    BackwardGData<int32_t, double>* gdata) {
  if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
    cuda::FallbackCallBackwardBinaryReduce<double>(rtcfg, graph, gdata);
  } else {
-    auto outcsr = graph->GetOutCSR();
-    Csr csr = utils::CreateCsr<int32_t>(outcsr->indptr(), outcsr->indices());
+    auto outcsr = graph.GetOutCSRMatrix();
+    Csr csr = utils::CreateCsr<int32_t>(outcsr.indptr, outcsr.indices);
    cuda::CusparseCsrmm2(rtcfg, csr, gdata->grad_out_data, gdata->grad_lhs_data,
                         gdata->out_size, gdata->x_length);
  }
...