Unverified commit 788420df authored by Quan (Andy) Gan, committed by GitHub

[Refactor] Decoupling ImmutableGraph from Kernels (#749)

* csr interface

* csr wrapper for immutable graph

* lint

* silly fix

* docstring
parent 35bed2a9
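In outline, this commit replaces the concrete const ImmutableGraph* parameter of every kernel entry point with an abstract const CSRWrapper&, and the C API handlers wrap the immutable graph in a thin adapter before calling in. The following is a self-contained sketch of that pattern; every type here is a simplified stand-in for illustration, not the real DGL class (the actual interface, defined in kernel/csr_interface.h further down, also exposes Context() and NumBits() for the device and index-width checks):

#include <iostream>

// Simplified stand-in for aten::CSRMatrix (illustration only).
struct CSRMatrix { int num_rows; };

// The small interface the kernels now depend on
// (mirrors the CSRWrapper declared in kernel/csr_interface.h below).
class CSRWrapper {
 public:
  virtual ~CSRWrapper() = default;
  virtual CSRMatrix GetInCSRMatrix() const = 0;
  virtual CSRMatrix GetOutCSRMatrix() const = 0;
};

// The concrete graph type the kernels no longer see directly.
class ImmutableGraph {
 public:
  CSRMatrix InCSR() const { return {3}; }   // stand-in data
  CSRMatrix OutCSR() const { return {3}; }
};

// Thin adapter, analogous to ImmutableGraphCSRWrapper in the diff.
class ImmutableGraphCSRWrapper : public CSRWrapper {
 public:
  explicit ImmutableGraphCSRWrapper(const ImmutableGraph* g) : g_(g) {}
  CSRMatrix GetInCSRMatrix() const override { return g_->InCSR(); }
  CSRMatrix GetOutCSRMatrix() const override { return g_->OutCSR(); }
 private:
  const ImmutableGraph* g_;
};

// Kernel entry point: depends on the interface, not on ImmutableGraph.
void BinaryOpReduce(const CSRWrapper& graph) {
  std::cout << graph.GetOutCSRMatrix().num_rows << " rows\n";
}

int main() {
  ImmutableGraph g;
  ImmutableGraphCSRWrapper wrapper(&g);  // built at the C API boundary
  BinaryOpReduce(wrapper);
}

This is exactly the shape of the changes below: each _CAPI_DGLKernel* handler now constructs an ImmutableGraphCSRWrapper from the checked ImmutableGraph pointer and passes it by reference.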
@@ -4,11 +4,13 @@
* \brief Binary reduce C APIs and definitions.
*/
#include <dgl/packed_func_ext.h>
#include <dgl/immutable_graph.h>
#include "./binary_reduce.h"
#include "./common.h"
#include "./binary_reduce_impl_decl.h"
#include "./utils.h"
#include "../c_api_common.h"
#include "./csr_interface.h"
using namespace dgl::runtime;
@@ -201,6 +203,31 @@ inline bool NeedSwitchOrder(const std::string& op,
&& lhs > rhs;
}
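// Callers below use this to canonicalize commutative ops: when it returns
// true, the lhs/rhs operands and their mappings are swapped (see the
// "Switch order for commutative operation" sections), so downstream code
// only has to handle one operand order.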
class ImmutableGraphCSRWrapper : public CSRWrapper {
public:
explicit ImmutableGraphCSRWrapper(const ImmutableGraph* graph) :
gptr_(graph) { }
aten::CSRMatrix GetInCSRMatrix() const override {
return gptr_->GetInCSR()->ToCSRMatrix();
}
aten::CSRMatrix GetOutCSRMatrix() const override {
return gptr_->GetOutCSR()->ToCSRMatrix();
}
DGLContext Context() const override {
return gptr_->Context();
}
int NumBits() const override {
return gptr_->NumBits();
}
private:
const ImmutableGraph* gptr_;
};
} // namespace
@@ -226,18 +253,18 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelInferBinaryFeatureShape")
void BinaryOpReduce(
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
NDArray lhs_data, NDArray rhs_data,
NDArray out_data,
NDArray lhs_mapping, NDArray rhs_mapping,
NDArray out_mapping) {
const auto& ctx = graph->Context();
const auto& ctx = graph.Context();
// sanity check
CheckCtx(ctx,
{lhs_data, rhs_data, out_data, lhs_mapping, rhs_mapping, out_mapping},
{"lhs_data", "rhs_data", "out_data", "lhs_mapping", "rhs_mapping", "out_mapping"});
CheckIdArray(graph->NumBits(),
CheckIdArray(graph.NumBits(),
{lhs_mapping, rhs_mapping, out_mapping},
{"lhs_mapping", "rhs_mapping", "out_mapping"});
// Switch order for commutative operation
@@ -282,7 +309,8 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
BinaryOpReduce(reducer, op, igptr.get(),
ImmutableGraphCSRWrapper wrapper(igptr.get());
BinaryOpReduce(reducer, op, wrapper,
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_data, rhs_data, out_data,
lhs_mapping, rhs_mapping, out_mapping);
@@ -291,7 +319,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
void BackwardLhsBinaryOpReduce(
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
NDArray lhs_mapping,
NDArray rhs_mapping,
@@ -301,14 +329,14 @@ void BackwardLhsBinaryOpReduce(
NDArray out_data,
NDArray grad_out_data,
NDArray grad_lhs_data) {
const auto& ctx = graph->Context();
const auto& ctx = graph.Context();
// sanity check
CheckCtx(ctx,
{lhs_data, rhs_data, out_data, grad_out_data, grad_lhs_data,
lhs_mapping, rhs_mapping, out_mapping},
{"lhs_data", "rhs_data", "out_data", "grad_out_data", "grad_lhs_data",
"lhs_mapping", "rhs_mapping", "out_mapping"});
CheckIdArray(graph->NumBits(),
CheckIdArray(graph.NumBits(),
{lhs_mapping, rhs_mapping, out_mapping},
{"lhs_mapping", "rhs_mapping", "out_mapping"});
// Switch order for commutative operation
@@ -356,8 +384,9 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
ImmutableGraphCSRWrapper wrapper(igptr.get());
BackwardLhsBinaryOpReduce(
reducer, op, igptr.get(),
reducer, op, wrapper,
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_mapping, rhs_mapping, out_mapping,
lhs_data, rhs_data, out_data, grad_out_data,
@@ -367,7 +396,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
void BackwardRhsBinaryOpReduce(
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
NDArray lhs_mapping,
NDArray rhs_mapping,
@@ -377,14 +406,14 @@ void BackwardRhsBinaryOpReduce(
NDArray out_data,
NDArray grad_out_data,
NDArray grad_rhs_data) {
const auto& ctx = graph->Context();
const auto& ctx = graph.Context();
// sanity check
CheckCtx(ctx,
{lhs_data, rhs_data, out_data, grad_out_data, grad_rhs_data,
lhs_mapping, rhs_mapping, out_mapping},
{"lhs_data", "rhs_data", "out_data", "grad_out_data", "grad_rhs_data",
"lhs_mapping", "rhs_mapping", "out_mapping"});
CheckIdArray(graph->NumBits(),
CheckIdArray(graph.NumBits(),
{lhs_mapping, rhs_mapping, out_mapping},
{"lhs_mapping", "rhs_mapping", "out_mapping"});
if (NeedSwitchOrder(op, lhs, rhs)) {
@@ -431,8 +460,9 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
ImmutableGraphCSRWrapper wrapper(igptr.get());
BackwardRhsBinaryOpReduce(
reducer, op, igptr.get(),
reducer, op, wrapper,
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_mapping, rhs_mapping, out_mapping,
lhs_data, rhs_data, out_data, grad_out_data,
@@ -441,16 +471,16 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
void CopyReduce(
const std::string& reducer,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target target,
NDArray in_data, NDArray out_data,
NDArray in_mapping, NDArray out_mapping) {
const auto& ctx = graph->Context();
const auto& ctx = graph.Context();
// sanity check
CheckCtx(ctx,
{in_data, out_data, in_mapping, out_mapping},
{"in_data", "out_data", "in_mapping", "out_mapping"});
CheckIdArray(graph->NumBits(),
CheckIdArray(graph.NumBits(),
{in_mapping, out_mapping},
{"in_mapping", "out_mapping"});
DGL_XPU_SWITCH(ctx.device_type, BinaryReduceImpl,
@@ -472,7 +502,8 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce")
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
CopyReduce(reducer, igptr.get(),
ImmutableGraphCSRWrapper wrapper(igptr.get());
CopyReduce(reducer, wrapper,
static_cast<binary_op::Target>(target),
in_data, out_data,
in_mapping, out_mapping);
@@ -480,7 +511,7 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce")
void BackwardCopyReduce(
const std::string& reducer,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target target,
NDArray in_mapping,
NDArray out_mapping,
@@ -488,12 +519,12 @@ void BackwardCopyReduce(
NDArray out_data,
NDArray grad_out_data,
NDArray grad_in_data) {
const auto& ctx = graph->Context();
const auto& ctx = graph.Context();
// sanity check
CheckCtx(ctx,
{in_data, out_data, grad_out_data, grad_in_data, in_mapping, out_mapping},
{"in_data", "out_data", "grad_out_data", "grad_in_data", "in_mapping", "out_mapping"});
CheckIdArray(graph->NumBits(),
CheckIdArray(graph.NumBits(),
{in_mapping, out_mapping},
{"in_mapping", "out_mapping"});
if (!utils::IsNoneArray(out_mapping)) {
@@ -522,8 +553,9 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce")
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(igptr) << "Invalid graph object argument. Must be an immutable graph.";
ImmutableGraphCSRWrapper wrapper(igptr.get());
BackwardCopyReduce(
reducer, igptr.get(), static_cast<binary_op::Target>(target),
reducer, wrapper, static_cast<binary_op::Target>(target),
in_mapping, out_mapping,
in_data, out_data, grad_out_data,
grad_in_data);
......
@@ -7,12 +7,12 @@
#define DGL_KERNEL_BINARY_REDUCE_H_
#include <dgl/runtime/ndarray.h>
#include <dgl/immutable_graph.h>
#include <vector>
#include <string>
#include "./binary_reduce_common.h"
#include "./csr_interface.h"
namespace dgl {
namespace kernel {
@@ -83,7 +83,7 @@ std::vector<int64_t> InferBinaryFeatureShape(
void BinaryOpReduce(
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
runtime::NDArray lhs_data, runtime::NDArray rhs_data,
runtime::NDArray out_data,
@@ -126,7 +126,7 @@ void BinaryOpReduce(
void BackwardLhsBinaryOpReduce(
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
runtime::NDArray lhs_mapping,
runtime::NDArray rhs_mapping,
@@ -173,7 +173,7 @@ void BackwardLhsBinaryOpReduce(
void BackwardRhsBinaryOpReduce(
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
runtime::NDArray lhs_mapping,
runtime::NDArray rhs_mapping,
@@ -213,7 +213,7 @@ void BackwardRhsBinaryOpReduce(
*/
void CopyReduce(
const std::string& reducer,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target target,
runtime::NDArray in_data, runtime::NDArray out_data,
runtime::NDArray in_mapping, runtime::NDArray out_mapping);
@@ -236,7 +236,7 @@ void CopyReduce(
*/
void BackwardCopyReduce(
const std::string& reducer,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target target,
runtime::NDArray in_mapping,
runtime::NDArray out_mapping,
......
@@ -8,7 +8,6 @@
#include <minigun/minigun.h>
#include <dgl/runtime/device_api.h>
#include <dgl/immutable_graph.h>
#include <algorithm>
#include <string>
@@ -18,6 +17,7 @@
#endif
#include "./binary_reduce.h"
#include "./binary_reduce_impl_decl.h"
#include "./csr_interface.h"
#include "./utils.h"
namespace dgl {
@@ -58,7 +58,7 @@ template <int XPU>
void BinaryReduceImpl(
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
runtime::NDArray lhs_data, runtime::NDArray rhs_data,
runtime::NDArray out_data,
@@ -88,7 +88,7 @@ void BinaryReduceImpl(
LOG(FATAL) << "reduce mean is not supported.";
}
const DLDataType& dtype = out_data->dtype;
const auto bits = graph->NumBits();
const auto bits = graph.NumBits();
DGL_DTYPE_SWITCH(dtype, DType, {
DGL_IDX_TYPE_SWITCH(bits, Idx, {
REDUCER_SWITCH(reducer, XPU, DType, Reducer, {
@@ -151,7 +151,7 @@ template <int XPU>
void BackwardBinaryReduceImpl(
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping,
runtime::NDArray lhs_data, runtime::NDArray rhs_data, runtime::NDArray out_data,
@@ -181,7 +181,7 @@ void BackwardBinaryReduceImpl(
const DLDataType& dtype = out_data->dtype;
const bool req_lhs = !utils::IsNoneArray(grad_lhs_data);
const bool req_rhs = !utils::IsNoneArray(grad_rhs_data);
const auto bits = graph->NumBits();
const auto bits = graph.NumBits();
if (reducer == binary_op::kReduceMean) {
// TODO(minjie): divide
LOG(FATAL) << "reduce mean is not supported.";
@@ -250,7 +250,7 @@ void BinaryReduceBcastImpl(
const BcastInfo& info,
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs,
binary_op::Target rhs,
runtime::NDArray lhs_data,
@@ -279,7 +279,7 @@
const DLDataType& dtype = out_data->dtype;
const int bcast_ndim = info.out_shape.size();
const auto bits = graph->NumBits();
const auto bits = graph.NumBits();
if (reducer == binary_op::kReduceMean) {
// TODO(minjie): divide
LOG(FATAL) << "reduce mean is not supported.";
@@ -359,7 +359,7 @@ void BackwardBinaryReduceBcastImpl(
const BcastInfo& info,
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs_tgt, binary_op::Target rhs_tgt,
runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping,
runtime::NDArray lhs, runtime::NDArray rhs, runtime::NDArray out, runtime::NDArray grad_out,
@@ -386,7 +386,7 @@
const int bcast_ndim = info.out_shape.size();
const bool req_lhs = !utils::IsNoneArray(grad_lhs);
const bool req_rhs = !utils::IsNoneArray(grad_rhs);
const auto bits = graph->NumBits();
const auto bits = graph.NumBits();
if (reducer == binary_op::kReduceMean) {
// TODO(minjie): divide
LOG(FATAL) << "reduce mean is not supported.";
......
@@ -11,6 +11,7 @@
#include <string>
#include "./binary_reduce_common.h"
#include "./csr_interface.h"
namespace minigun {
namespace advance {
@@ -21,9 +22,6 @@ struct RuntimeConfig;
namespace dgl {
// forward declaration
class ImmutableGraph;
namespace kernel {
// forward declaration
@@ -81,7 +79,7 @@ template <int XPU, typename Idx, typename DType,
typename BinaryOp, typename Reducer>
void CallBinaryReduce(
const minigun::advance::RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
GData<Idx, DType>* gdata);
/*!
@@ -107,7 +105,7 @@ template <int XPU>
void BinaryReduceImpl(
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
runtime::NDArray lhs_data, runtime::NDArray rhs_data, runtime::NDArray out_data,
runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping);
@@ -168,7 +166,7 @@ template <int XPU, int Mode, typename Idx, typename DType,
typename BinaryOp, typename Reducer>
void CallBackwardBinaryReduce(
const minigun::advance::RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
BackwardGData<Idx, DType>* gdata);
/*!
@@ -196,7 +194,7 @@ template <int XPU>
void BackwardBinaryReduceImpl(
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping,
runtime::NDArray lhs_data, runtime::NDArray rhs_data, runtime::NDArray out_data,
@@ -269,7 +267,7 @@ template <int XPU, int NDim, typename Idx, typename DType,
typename BinaryOp, typename Reducer>
void CallBinaryReduceBcast(
const minigun::advance::RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
BcastGData<NDim, Idx, DType>* gdata);
/*!
@@ -296,7 +294,7 @@ void BinaryReduceBcastImpl(
const BcastInfo& info,
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
runtime::NDArray lhs_data, runtime::NDArray rhs_data,
runtime::NDArray out_data,
@@ -370,7 +368,7 @@ template <int XPU, int Mode, int NDim, typename Idx, typename DType,
typename BinaryOp, typename Reducer>
void CallBackwardBinaryReduceBcast(
const minigun::advance::RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
BackwardBcastGData<NDim, Idx, DType>* gdata);
/*!
@@ -399,7 +397,7 @@ void BackwardBinaryReduceBcastImpl(
const BcastInfo& info,
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping,
runtime::NDArray lhs_data, runtime::NDArray rhs_data, runtime::NDArray out_data,
......
@@ -7,11 +7,11 @@
#define DGL_KERNEL_CPU_BACKWARD_BINARY_REDUCE_IMPL_H_
#include <minigun/minigun.h>
#include <dgl/immutable_graph.h>
#include "../binary_reduce_impl_decl.h"
#include "../utils.h"
#include "./functor.h"
#include "../csr_interface.h"
namespace dgl {
namespace kernel {
@@ -170,14 +170,14 @@ template <int XPU, int Mode, typename Idx, typename DType,
typename BinaryOp, typename Reducer>
void CallBackwardBinaryReduce(
const minigun::advance::RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
BackwardGData<Idx, DType>* gdata) {
// For backward computation, we use reverse csr and switch dst and src.
// This benefits the most common src_op_edge or copy_src case, because the
// gradients of src are now aggregated into the destination buffer, which
// reduces contention from atomic adds.
auto incsr = graph->GetInCSR();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr->indptr(), incsr->indices());
auto incsr = graph.GetInCSRMatrix();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr.indptr, incsr.indices);
typedef cpu::BackwardFunctorsTempl<Idx, DType,
typename SwitchSrcDst<LeftSelector>::Type,
typename SwitchSrcDst<RightSelector>::Type,
@@ -188,15 +188,15 @@ void CallBackwardBinaryReduce(
// data is correctly read/written.
if (LeftSelector::target == binary_op::kEdge
&& gdata->lhs_mapping == nullptr) {
gdata->lhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
gdata->lhs_mapping = static_cast<Idx*>(incsr.data->data);
}
if (RightSelector::target == binary_op::kEdge
&& gdata->rhs_mapping == nullptr) {
gdata->rhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
gdata->rhs_mapping = static_cast<Idx*>(incsr.data->data);
}
if (OutSelector<Reducer>::Type::target == binary_op::kEdge
&& gdata->out_mapping == nullptr) {
gdata->out_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
gdata->out_mapping = static_cast<Idx*>(incsr.data->data);
}
// TODO(minjie): allocator
minigun::advance::Advance<XPU, Idx, cpu::AdvanceConfig, BackwardGData<Idx, DType>, UDF>(
@@ -211,7 +211,7 @@ void CallBackwardBinaryReduce(
lhs_tgt, rhs_tgt, \
op<dtype>, REDUCER<XPU, dtype>>( \
const minigun::advance::RuntimeConfig& rtcfg, \
const ImmutableGraph* graph, \
const CSRWrapper& graph, \
BackwardGData<IDX, dtype>* gdata);
// Template implementation of BackwardBinaryReduce with broadcasting operator.
@@ -220,14 +220,14 @@ template <int XPU, int Mode, int NDim, typename Idx, typename DType,
typename BinaryOp, typename Reducer>
void CallBackwardBinaryReduceBcast(
const minigun::advance::RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
BackwardBcastGData<NDim, Idx, DType>* gdata) {
// For backward computation, we use reverse csr and switch dst and src.
// This benefits the most common src_op_edge or copy_src case, because the
// gradients of src are now aggregated into the destination buffer, which
// reduces contention from atomic adds.
auto incsr = graph->GetInCSR();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr->indptr(), incsr->indices());
auto incsr = graph.GetInCSRMatrix();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr.indptr, incsr.indices);
typedef cpu::BackwardFunctorsTempl<Idx, DType,
typename SwitchSrcDst<LeftSelector>::Type,
typename SwitchSrcDst<RightSelector>::Type,
@@ -238,15 +238,15 @@ void CallBackwardBinaryReduceBcast(
// data is correctly read/written.
if (LeftSelector::target == binary_op::kEdge
&& gdata->lhs_mapping == nullptr) {
gdata->lhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
gdata->lhs_mapping = static_cast<Idx*>(incsr.data->data);
}
if (RightSelector::target == binary_op::kEdge
&& gdata->rhs_mapping == nullptr) {
gdata->rhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
gdata->rhs_mapping = static_cast<Idx*>(incsr.data->data);
}
if (OutSelector<Reducer>::Type::target == binary_op::kEdge
&& gdata->out_mapping == nullptr) {
gdata->out_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
gdata->out_mapping = static_cast<Idx*>(incsr.data->data);
}
// TODO(minjie): allocator
minigun::advance::Advance<XPU, Idx, cpu::AdvanceConfig,
@@ -262,7 +262,7 @@ void CallBackwardBinaryReduceBcast(
lhs_tgt, rhs_tgt, \
op<dtype>, REDUCER<XPU, dtype>>( \
const minigun::advance::RuntimeConfig& rtcfg, \
const ImmutableGraph* graph, \
const CSRWrapper& graph, \
BackwardBcastGData<ndim, IDX, dtype>* gdata);
} // namespace kernel
......
@@ -4,6 +4,7 @@
* \brief Binary reduce implementation on CPU.
*/
#include "../binary_reduce_impl.h"
#include "../csr_interface.h"
using dgl::runtime::NDArray;
@@ -13,7 +14,7 @@ namespace kernel {
template void BinaryReduceImpl<kDLCPU>(
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
runtime::NDArray lhs_data, runtime::NDArray rhs_data,
runtime::NDArray out_data,
@@ -24,7 +25,7 @@ template void BinaryReduceBcastImpl<kDLCPU>(
const BcastInfo& info,
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
runtime::NDArray lhs_data, runtime::NDArray rhs_data,
runtime::NDArray out_data,
@@ -34,7 +35,7 @@ template void BinaryReduceBcastImpl<kDLCPU>(
template void BackwardBinaryReduceImpl<kDLCPU>(
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
NDArray lhs_mapping, NDArray rhs_mapping, NDArray out_mapping,
NDArray lhs_data, NDArray rhs_data, NDArray out_data,
@@ -45,7 +46,7 @@ template void BackwardBinaryReduceBcastImpl<kDLCPU>(
const BcastInfo& info,
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs_tgt, binary_op::Target rhs_tgt,
runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping,
runtime::NDArray lhs, runtime::NDArray rhs, runtime::NDArray out, runtime::NDArray grad_out,
......
@@ -7,13 +7,13 @@
#define DGL_KERNEL_CPU_BINARY_REDUCE_IMPL_H_
#include <minigun/minigun.h>
#include <dgl/immutable_graph.h>
#include <algorithm>
#include "../binary_reduce_impl_decl.h"
#include "../utils.h"
#include "./functor.h"
#include "../csr_interface.h"
namespace dgl {
namespace kernel {
@@ -148,27 +148,27 @@ template <int XPU, typename Idx, typename DType,
typename LeftSelector, typename RightSelector,
typename BinaryOp, typename Reducer>
void CallBinaryReduce(const minigun::advance::RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
GData<Idx, DType>* gdata) {
typedef cpu::FunctorsTempl<Idx, DType, LeftSelector,
RightSelector, BinaryOp, Reducer>
Functors;
typedef cpu::BinaryReduce<Idx, DType, Functors> UDF;
// csr
auto outcsr = graph->GetOutCSR();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr->indptr(), outcsr->indices());
auto outcsr = graph.GetOutCSRMatrix();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr.indptr, outcsr.indices);
// If the user-given mapping is none and the target is edge data, we need to
// replace the mapping by the edge ids in the csr graph so that the edge
// data is correctly read/written.
if (LeftSelector::target == binary_op::kEdge && gdata->lhs_mapping == nullptr) {
gdata->lhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
gdata->lhs_mapping = static_cast<Idx*>(outcsr.data->data);
}
if (RightSelector::target == binary_op::kEdge && gdata->rhs_mapping == nullptr) {
gdata->rhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
gdata->rhs_mapping = static_cast<Idx*>(outcsr.data->data);
}
if (OutSelector<Reducer>::Type::target == binary_op::kEdge
&& gdata->out_mapping == nullptr) {
gdata->out_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
gdata->out_mapping = static_cast<Idx*>(outcsr.data->data);
}
// TODO(minjie): allocator
minigun::advance::Advance<XPU, Idx, cpu::AdvanceConfig, GData<Idx, DType>, UDF>(
@@ -181,27 +181,27 @@ template <int XPU, int NDim, typename Idx, typename DType,
typename BinaryOp, typename Reducer>
void CallBinaryReduceBcast(
const minigun::advance::RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
BcastGData<NDim, Idx, DType>* gdata) {
typedef cpu::FunctorsTempl<Idx, DType, LeftSelector,
RightSelector, BinaryOp, Reducer>
Functors;
typedef cpu::BinaryReduceBcast<NDim, Idx, DType, Functors> UDF;
// csr
auto outcsr = graph->GetOutCSR();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr->indptr(), outcsr->indices());
auto outcsr = graph.GetOutCSRMatrix();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr.indptr, outcsr.indices);
// If the user-given mapping is none and the target is edge data, we need to
// replace the mapping by the edge ids in the csr graph so that the edge
// data is correctly read/written.
if (LeftSelector::target == binary_op::kEdge && gdata->lhs_mapping == nullptr) {
gdata->lhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
gdata->lhs_mapping = static_cast<Idx*>(outcsr.data->data);
}
if (RightSelector::target == binary_op::kEdge && gdata->rhs_mapping == nullptr) {
gdata->rhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
gdata->rhs_mapping = static_cast<Idx*>(outcsr.data->data);
}
if (OutSelector<Reducer>::Type::target == binary_op::kEdge
&& gdata->out_mapping == nullptr) {
gdata->out_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
gdata->out_mapping = static_cast<Idx*>(outcsr.data->data);
}
// TODO(minjie): allocator
minigun::advance::Advance<XPU, Idx, cpu::AdvanceConfig,
@@ -215,7 +215,7 @@ void CallBinaryReduceBcast(
template void CallBinaryReduce<XPU, IDX, \
dtype, lhs_tgt, rhs_tgt, op<dtype>, REDUCER<XPU, dtype>>( \
const minigun::advance::RuntimeConfig& rtcfg, \
const ImmutableGraph* graph, \
const CSRWrapper& graph, \
GData<IDX, dtype>* gdata);
#define GEN_BCAST_DEFINE(ndim, dtype, lhs_tgt, rhs_tgt, op) \
@@ -223,7 +223,7 @@ void CallBinaryReduceBcast(
lhs_tgt, rhs_tgt, \
op<dtype>, REDUCER<XPU, dtype>>( \
const minigun::advance::RuntimeConfig& rtcfg, \
const ImmutableGraph* graph, \
const CSRWrapper& graph, \
BcastGData<ndim, IDX, dtype>* gdata);
#define EVAL(F, ...) MSVC_EXPAND(F(__VA_ARGS__))
......
/*!
* Copyright (c) 2019 by Contributors
* \file kernel/csr_interface.h
* \brief CSR wrapper interface for the kernels
*/
#ifndef DGL_KERNEL_CSR_INTERFACE_H_
#define DGL_KERNEL_CSR_INTERFACE_H_
#include <dgl/array.h>
#include <dgl/runtime/c_runtime_api.h>
namespace dgl {
namespace kernel {
/*!
* \brief Wrapper class that unifies ImmutableGraph and Bipartite, which do
* not share a base class.
*
* \note This is an ugly temporary solution, and shall be removed after
* refactoring ImmutableGraph and Bipartite to use the same data structure.
*/
class CSRWrapper {
public:
virtual aten::CSRMatrix GetInCSRMatrix() const = 0;
virtual aten::CSRMatrix GetOutCSRMatrix() const = 0;
virtual DGLContext Context() const = 0;
virtual int NumBits() const = 0;
};
}  // namespace kernel
}  // namespace dgl
#endif // DGL_KERNEL_CSR_INTERFACE_H_
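For reference when reading the call sites in this diff: an aten::CSRMatrix carries the standard CSR triplet indptr/indices plus a data array, and the data array is what replaces the old edge_ids() accessor, i.e. it holds the edge ID of each stored entry so per-edge features can still be looked up. A minimal plain-vector illustration of that layout (standard CSR semantics; std::vector stands in for DGL's NDArray):

#include <cstdio>
#include <vector>

int main() {
  // Out-CSR of the graph 0->1, 0->2, 2->1 (row u = source node u).
  std::vector<int> indptr  = {0, 2, 2, 3};  // row u spans [indptr[u], indptr[u+1])
  std::vector<int> indices = {1, 2, 1};     // destination of each stored edge
  std::vector<int> data    = {0, 1, 2};     // edge ID of each stored edge

  for (int u = 0; u + 1 < static_cast<int>(indptr.size()); ++u) {
    for (int e = indptr[u]; e < indptr[u + 1]; ++e) {
      // data[e] is the index used to read/write per-edge features; this is
      // what the kernels assign to lhs/rhs/out_mapping when none is given.
      std::printf("edge %d: %d -> %d\n", data[e], u, indices[e]);
    }
  }
  return 0;
}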
@@ -7,11 +7,11 @@
#define DGL_KERNEL_CUDA_BACKWARD_BINARY_REDUCE_IMPL_CUH_
#include <minigun/minigun.h>
#include <dgl/immutable_graph.h>
#include "../binary_reduce_impl_decl.h"
#include "../utils.h"
#include "./functor.cuh"
#include "../csr_interface.h"
namespace dgl {
namespace kernel {
@@ -171,14 +171,14 @@ template <int XPU, int Mode, typename Idx, typename DType,
typename BinaryOp, typename Reducer>
void CallBackwardBinaryReduce(
const minigun::advance::RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
BackwardGData<Idx, DType>* gdata) {
// For backward computation, we use reverse csr and switch dst and src.
// This benefits the most common src_op_edge or copy_src case, because the
// gradients of src are now aggregated into the destination buffer, which
// reduces contention from atomic adds.
auto incsr = graph->GetInCSR();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr->indptr(), incsr->indices());
auto incsr = graph.GetInCSRMatrix();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr.indptr, incsr.indices);
typedef cuda::BackwardFunctorsTempl<Idx, DType,
typename SwitchSrcDst<LeftSelector>::Type,
typename SwitchSrcDst<RightSelector>::Type,
@@ -189,15 +189,15 @@ void CallBackwardBinaryReduce(
// data is correctly read/written.
if (LeftSelector::target == binary_op::kEdge
&& gdata->lhs_mapping == nullptr) {
gdata->lhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
gdata->lhs_mapping = static_cast<Idx*>(incsr.data->data);
}
if (RightSelector::target == binary_op::kEdge
&& gdata->rhs_mapping == nullptr) {
gdata->rhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
gdata->rhs_mapping = static_cast<Idx*>(incsr.data->data);
}
if (OutSelector<Reducer>::Type::target == binary_op::kEdge
&& gdata->out_mapping == nullptr) {
gdata->out_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
gdata->out_mapping = static_cast<Idx*>(incsr.data->data);
}
// TODO(minjie): allocator
minigun::advance::Advance<XPU, Idx, cuda::AdvanceConfig, BackwardGData<Idx, DType>, UDF>(
@@ -212,7 +212,7 @@ void CallBackwardBinaryReduce(
lhs_tgt, rhs_tgt, \
op<dtype>, REDUCER<XPU, dtype>>( \
const minigun::advance::RuntimeConfig& rtcfg, \
const ImmutableGraph* graph, \
const CSRWrapper& graph, \
BackwardGData<IDX, dtype>* gdata);
// Template implementation of BackwardBinaryReduce with broadcasting operator.
@@ -221,14 +221,14 @@ template <int XPU, int Mode, int NDim, typename Idx, typename DType,
typename BinaryOp, typename Reducer>
void CallBackwardBinaryReduceBcast(
const minigun::advance::RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
BackwardBcastGData<NDim, Idx, DType>* gdata) {
// For backward computation, we use reverse csr and switch dst and src.
// This benefits the most common src_op_edge or copy_src case, because the
// gradients of src are now aggregated into the destination buffer, which
// reduces contention from atomic adds.
auto incsr = graph->GetInCSR();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr->indptr(), incsr->indices());
auto incsr = graph.GetInCSRMatrix();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr.indptr, incsr.indices);
typedef cuda::BackwardFunctorsTempl<Idx, DType,
typename SwitchSrcDst<LeftSelector>::Type,
typename SwitchSrcDst<RightSelector>::Type,
@@ -239,15 +239,15 @@ void CallBackwardBinaryReduceBcast(
// data is correctly read/written.
if (LeftSelector::target == binary_op::kEdge
&& gdata->lhs_mapping == nullptr) {
gdata->lhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
gdata->lhs_mapping = static_cast<Idx*>(incsr.data->data);
}
if (RightSelector::target == binary_op::kEdge
&& gdata->rhs_mapping == nullptr) {
gdata->rhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
gdata->rhs_mapping = static_cast<Idx*>(incsr.data->data);
}
if (OutSelector<Reducer>::Type::target == binary_op::kEdge
&& gdata->out_mapping == nullptr) {
gdata->out_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
gdata->out_mapping = static_cast<Idx*>(incsr.data->data);
}
// TODO(minjie): allocator
minigun::advance::Advance<XPU, Idx, cuda::AdvanceConfig,
@@ -263,7 +263,7 @@ void CallBackwardBinaryReduceBcast(
lhs_tgt, rhs_tgt, \
op<dtype>, REDUCER<XPU, dtype>>( \
const minigun::advance::RuntimeConfig& rtcfg, \
const ImmutableGraph* graph, \
const CSRWrapper& graph, \
BackwardBcastGData<ndim, IDX, dtype>* gdata);
} // namespace kernel
......
@@ -4,6 +4,7 @@
* \brief Binary reduce implementation on cuda.
*/
#include "../binary_reduce_impl.h"
#include "../csr_interface.h"
using dgl::runtime::NDArray;
@@ -13,7 +14,7 @@ namespace kernel {
template void BinaryReduceImpl<kDLGPU>(
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
runtime::NDArray lhs_data, runtime::NDArray rhs_data,
runtime::NDArray out_data,
@@ -24,7 +25,7 @@ template void BinaryReduceBcastImpl<kDLGPU>(
const BcastInfo& info,
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
runtime::NDArray lhs_data, runtime::NDArray rhs_data,
runtime::NDArray out_data,
@@ -34,7 +35,7 @@ template void BinaryReduceBcastImpl<kDLGPU>(
template void BackwardBinaryReduceImpl<kDLGPU>(
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs, binary_op::Target rhs,
NDArray lhs_mapping, NDArray rhs_mapping, NDArray out_mapping,
NDArray lhs_data, NDArray rhs_data, NDArray out_data,
@@ -45,7 +46,7 @@ template void BackwardBinaryReduceBcastImpl<kDLGPU>(
const BcastInfo& info,
const std::string& reducer,
const std::string& op,
const ImmutableGraph* graph,
const CSRWrapper& graph,
binary_op::Target lhs_tgt, binary_op::Target rhs_tgt,
runtime::NDArray lhs_mapping, runtime::NDArray rhs_mapping, runtime::NDArray out_mapping,
runtime::NDArray lhs, runtime::NDArray rhs, runtime::NDArray out, runtime::NDArray grad_out,
......
@@ -7,11 +7,11 @@
#define DGL_KERNEL_CUDA_BINARY_REDUCE_IMPL_CUH_
#include <minigun/minigun.h>
#include <dgl/immutable_graph.h>
#include "../binary_reduce_impl_decl.h"
#include "../utils.h"
#include "./functor.cuh"
#include "../csr_interface.h"
namespace dgl {
namespace kernel {
@@ -151,27 +151,27 @@ template <int XPU, typename Idx, typename DType,
typename LeftSelector, typename RightSelector,
typename BinaryOp, typename Reducer>
void CallBinaryReduce(const minigun::advance::RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
GData<Idx, DType>* gdata) {
typedef cuda::FunctorsTempl<Idx, DType, LeftSelector,
RightSelector, BinaryOp, Reducer>
Functors;
typedef cuda::BinaryReduce<Idx, DType, Functors> UDF;
// csr
auto outcsr = graph->GetOutCSR();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr->indptr(), outcsr->indices());
auto outcsr = graph.GetOutCSRMatrix();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr.indptr, outcsr.indices);
// If the user-given mapping is none and the target is edge data, we need to
// replace the mapping by the edge ids in the csr graph so that the edge
// data is correctly read/written.
if (LeftSelector::target == binary_op::kEdge && gdata->lhs_mapping == nullptr) {
gdata->lhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
gdata->lhs_mapping = static_cast<Idx*>(outcsr.data->data);
}
if (RightSelector::target == binary_op::kEdge && gdata->rhs_mapping == nullptr) {
gdata->rhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
gdata->rhs_mapping = static_cast<Idx*>(outcsr.data->data);
}
if (OutSelector<Reducer>::Type::target == binary_op::kEdge
&& gdata->out_mapping == nullptr) {
gdata->out_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
gdata->out_mapping = static_cast<Idx*>(outcsr.data->data);
}
// TODO(minjie): allocator
minigun::advance::Advance<XPU, Idx, cuda::AdvanceConfig, GData<Idx, DType>, UDF>(
@@ -184,27 +184,27 @@ template <int XPU, int NDim, typename Idx, typename DType,
typename BinaryOp, typename Reducer>
void CallBinaryReduceBcast(
const minigun::advance::RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
BcastGData<NDim, Idx, DType>* gdata) {
typedef cuda::FunctorsTempl<Idx, DType, LeftSelector,
RightSelector, BinaryOp, Reducer>
Functors;
typedef cuda::BinaryReduceBcast<NDim, Idx, DType, Functors> UDF;
// csr
auto outcsr = graph->GetOutCSR();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr->indptr(), outcsr->indices());
auto outcsr = graph.GetOutCSRMatrix();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr.indptr, outcsr.indices);
// If the user-given mapping is none and the target is edge data, we need to
// replace the mapping by the edge ids in the csr graph so that the edge
// data is correctly read/written.
if (LeftSelector::target == binary_op::kEdge && gdata->lhs_mapping == nullptr) {
gdata->lhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
gdata->lhs_mapping = static_cast<Idx*>(outcsr.data->data);
}
if (RightSelector::target == binary_op::kEdge && gdata->rhs_mapping == nullptr) {
gdata->rhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
gdata->rhs_mapping = static_cast<Idx*>(outcsr.data->data);
}
if (OutSelector<Reducer>::Type::target == binary_op::kEdge
&& gdata->out_mapping == nullptr) {
gdata->out_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
gdata->out_mapping = static_cast<Idx*>(outcsr.data->data);
}
// TODO(minjie): allocator
minigun::advance::Advance<XPU, Idx, cuda::AdvanceConfig,
@@ -218,7 +218,7 @@ void CallBinaryReduceBcast(
template void CallBinaryReduce<XPU, IDX, \
dtype, lhs_tgt, rhs_tgt, op<dtype>, REDUCER<XPU, dtype>>( \
const minigun::advance::RuntimeConfig& rtcfg, \
const ImmutableGraph* graph, \
const CSRWrapper& graph, \
GData<IDX, dtype>* gdata);
#define GEN_BCAST_DEFINE(ndim, dtype, lhs_tgt, rhs_tgt, op) \
@@ -226,7 +226,7 @@ void CallBinaryReduceBcast(
lhs_tgt, rhs_tgt, \
op<dtype>, REDUCER<XPU, dtype>>( \
const minigun::advance::RuntimeConfig& rtcfg, \
const ImmutableGraph* graph, \
const CSRWrapper& graph, \
BcastGData<ndim, IDX, dtype>* gdata);
#define EVAL(F, ...) MSVC_EXPAND(F(__VA_ARGS__))
......
@@ -9,6 +9,7 @@
#include "./binary_reduce_impl.cuh"
#include "./backward_binary_reduce_impl.cuh"
#include "../utils.h"
#include "../csr_interface.h"
using minigun::advance::RuntimeConfig;
using Csr = minigun::Csr<int32_t>;
@@ -153,7 +154,7 @@ void CusparseCsrmm2(
template <typename DType>
void FallbackCallBinaryReduce(
const RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
GData<int32_t, DType>* gdata) {
constexpr int XPU = kDLGPU;
typedef int32_t Idx;
@@ -166,20 +167,20 @@ void FallbackCallBinaryReduce(
Functors;
typedef cuda::BinaryReduce<Idx, DType, Functors> UDF;
// csr
auto outcsr = graph->GetOutCSR();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr->indptr(), outcsr->indices());
auto outcsr = graph.GetOutCSRMatrix();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr.indptr, outcsr.indices);
// If the user-given mapping is none and the target is edge data, we need to
// replace the mapping by the edge ids in the csr graph so that the edge
// data is correctly read/written.
if (LeftSelector::target == binary_op::kEdge && gdata->lhs_mapping == nullptr) {
gdata->lhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
gdata->lhs_mapping = static_cast<Idx*>(outcsr.data->data);
}
if (RightSelector::target == binary_op::kEdge && gdata->rhs_mapping == nullptr) {
gdata->rhs_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
gdata->rhs_mapping = static_cast<Idx*>(outcsr.data->data);
}
if (OutSelector<Reducer>::Type::target == binary_op::kEdge
&& gdata->out_mapping == nullptr) {
gdata->out_mapping = static_cast<Idx*>(outcsr->edge_ids()->data);
gdata->out_mapping = static_cast<Idx*>(outcsr.data->data);
}
// TODO(minjie): allocator
minigun::advance::Advance<XPU, Idx, cuda::AdvanceConfig, GData<Idx, DType>, UDF>(
@@ -189,7 +190,7 @@ void FallbackCallBinaryReduce(
template <typename DType>
void FallbackCallBackwardBinaryReduce(
const RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
BackwardGData<int32_t, DType>* gdata) {
constexpr int XPU = kDLGPU;
constexpr int Mode = binary_op::kGradLhs;
@@ -202,8 +203,8 @@ void FallbackCallBackwardBinaryReduce(
// This benefits the most common src_op_edge or copy_src case, because the
// gradients of src are now aggregated into the destination buffer, which
// reduces contention from atomic adds.
auto incsr = graph->GetInCSR();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr->indptr(), incsr->indices());
auto incsr = graph.GetInCSRMatrix();
minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr.indptr, incsr.indices);
typedef cuda::BackwardFunctorsTempl<Idx, DType,
typename SwitchSrcDst<LeftSelector>::Type,
typename SwitchSrcDst<RightSelector>::Type,
@@ -214,15 +215,15 @@ void FallbackCallBackwardBinaryReduce(
// data is correctly read/written.
if (LeftSelector::target == binary_op::kEdge
&& gdata->lhs_mapping == nullptr) {
gdata->lhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
gdata->lhs_mapping = static_cast<Idx*>(incsr.data->data);
}
if (RightSelector::target == binary_op::kEdge
&& gdata->rhs_mapping == nullptr) {
gdata->rhs_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
gdata->rhs_mapping = static_cast<Idx*>(incsr.data->data);
}
if (OutSelector<Reducer>::Type::target == binary_op::kEdge
&& gdata->out_mapping == nullptr) {
gdata->out_mapping = static_cast<Idx*>(incsr->edge_ids()->data);
gdata->out_mapping = static_cast<Idx*>(incsr.data->data);
}
// TODO(minjie): allocator
minigun::advance::Advance<XPU, Idx, cuda::AdvanceConfig, BackwardGData<Idx, DType>, UDF>(
@@ -235,14 +236,14 @@ template <>
void CallBinaryReduce<kDLGPU, int32_t, float, SelectSrc, SelectNone,
BinaryUseLhs<float>, ReduceSum<kDLGPU, float>>(
const RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
GData<int32_t, float>* gdata) {
if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
cuda::FallbackCallBinaryReduce<float>(rtcfg, graph, gdata);
} else {
// cusparse uses the reverse (in-)csr for csrmm
auto incsr = graph->GetInCSR();
Csr csr = utils::CreateCsr<int32_t>(incsr->indptr(), incsr->indices());
auto incsr = graph.GetInCSRMatrix();
Csr csr = utils::CreateCsr<int32_t>(incsr.indptr, incsr.indices);
cuda::CusparseCsrmm2(rtcfg, csr, gdata->lhs_data, gdata->out_data,
gdata->out_size, gdata->x_length);
}
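// Note on the cusparse specializations here: for copy-src with sum reduction,
// each output row is out[v, :] = sum_{u in N_in(v)} lhs[u, :], which is the
// sparse-dense product Out = A * Lhs with A the in-CSR (reverse) adjacency,
// one row per destination node. That is why, whenever no lhs/rhs/out mappings
// are supplied, a single csrmm call computes the whole reduction and no
// custom kernel is needed.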
@@ -252,14 +253,14 @@ template <>
void CallBinaryReduce<kDLGPU, int32_t, double, SelectSrc, SelectNone,
BinaryUseLhs<double>, ReduceSum<kDLGPU, double>>(
const RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
GData<int32_t, double>* gdata) {
if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
cuda::FallbackCallBinaryReduce<double>(rtcfg, graph, gdata);
} else {
// cusparse uses the reverse (in-)csr for csrmm
auto incsr = graph->GetInCSR();
Csr csr = utils::CreateCsr<int32_t>(incsr->indptr(), incsr->indices());
auto incsr = graph.GetInCSRMatrix();
Csr csr = utils::CreateCsr<int32_t>(incsr.indptr, incsr.indices);
cuda::CusparseCsrmm2(rtcfg, csr, gdata->lhs_data, gdata->out_data,
gdata->out_size, gdata->x_length);
}
@@ -272,13 +273,13 @@ void CallBackwardBinaryReduce<kDLGPU, binary_op::kGradLhs, int32_t, float,
SelectSrc, SelectNone,
BinaryUseLhs<float>, ReduceSum<kDLGPU, float>>(
const RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
BackwardGData<int32_t, float>* gdata) {
if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
cuda::FallbackCallBackwardBinaryReduce<float>(rtcfg, graph, gdata);
} else {
auto outcsr = graph->GetOutCSR();
Csr csr = utils::CreateCsr<int32_t>(outcsr->indptr(), outcsr->indices());
auto outcsr = graph.GetOutCSRMatrix();
Csr csr = utils::CreateCsr<int32_t>(outcsr.indptr, outcsr.indices);
cuda::CusparseCsrmm2(rtcfg, csr, gdata->grad_out_data, gdata->grad_lhs_data,
gdata->out_size, gdata->x_length);
}
@@ -289,13 +290,13 @@ void CallBackwardBinaryReduce<kDLGPU, binary_op::kGradLhs, int32_t, double,
SelectSrc, SelectNone,
BinaryUseLhs<double>, ReduceSum<kDLGPU, double>>(
const RuntimeConfig& rtcfg,
const ImmutableGraph* graph,
const CSRWrapper& graph,
BackwardGData<int32_t, double>* gdata) {
if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
cuda::FallbackCallBackwardBinaryReduce<double>(rtcfg, graph, gdata);
} else {
auto outcsr = graph->GetOutCSR();
Csr csr = utils::CreateCsr<int32_t>(outcsr->indptr(), outcsr->indices());
auto outcsr = graph.GetOutCSRMatrix();
Csr csr = utils::CreateCsr<int32_t>(outcsr.indptr, outcsr.indices);
cuda::CusparseCsrmm2(rtcfg, csr, gdata->grad_out_data, gdata->grad_lhs_data,
gdata->out_size, gdata->x_length);
}
......