/*!
 *  Copyright (c) 2019 by Contributors
 * \file kernel/cpu/backward_binary_reduce_impl.h
 * \brief Minigun CPU UDFs for backward binary reduce
 */
#ifndef DGL_KERNEL_CPU_BACKWARD_BINARY_REDUCE_IMPL_H_
#define DGL_KERNEL_CPU_BACKWARD_BINARY_REDUCE_IMPL_H_

#include <minigun/minigun.h>

#include "../binary_reduce_impl_decl.h"
#include "../utils.h"
#include "./functor.h"
#include "../csr_interface.h"

namespace dgl {
namespace kernel {
namespace cpu {

// Minigun UDF to compute backward binary reduce.
template <int Mode, typename Idx, typename DType, typename Functors>
struct BackwardBinaryReduce {
  static inline bool CondEdge(
      Idx src, Idx dst, Idx eid, BackwardGData<Idx, DType>* gdata) {
    return true;
  }
  static inline void ApplyEdge(
      Idx src, Idx dst, Idx eid, BackwardGData<Idx, DType>* gdata) {
    const int64_t D = gdata->x_length;
    Idx lid = Functors::SelectLeft(src, eid, dst);
    Idx rid = Functors::SelectRight(src, eid, dst);
    Idx oid = Functors::SelectOut(src, eid, dst);
    if (gdata->lhs_mapping) {
      lid = Functors::GetId(lid, gdata->lhs_mapping);
    }
    if (gdata->rhs_mapping) {
      rid = Functors::GetId(rid, gdata->rhs_mapping);
    }
    if (gdata->out_mapping) {
      oid = Functors::GetId(oid, gdata->out_mapping);
    }
    DType* lhsoff = gdata->lhs_data + lid * D;
    DType* rhsoff = gdata->rhs_data + rid * D;
    DType* outoff = gdata->out_data + oid * D;
    DType* gradlhsoff = gdata->grad_lhs_data + lid * D;
    DType* gradrhsoff = gdata->grad_rhs_data + rid * D;
    DType* gradoutoff = gdata->grad_out_data + oid * D;
    for (int64_t tx = 0; tx < D; ++tx) {
      DType lhs = Functors::Read(lhsoff + tx);
      DType rhs = Functors::Read(rhsoff + tx);
      DType out = Functors::Read(outoff + tx);
      DType grad_out = Functors::Read(gradoutoff + tx);
      DType e = Functors::Op(lhs, rhs);
      DType grad_e = grad_out * Functors::BackwardWrite(e, out);
      if (Mode == binary_op::kGradLhs || Mode == binary_op::kGradBoth) {
        DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
#pragma omp atomic
        gradlhsoff[tx] += grad_lhs;
      }
      if (Mode == binary_op::kGradRhs || Mode == binary_op::kGradBoth) {
        DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
#pragma omp atomic
        gradrhsoff[tx] += grad_rhs;
      }
    }
  }
};
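// Illustrative example of the functor interface assumed by the UDFs above: a
// hypothetical elementwise-multiply BinaryOp paired with a sum Reducer. The
// real operators are declared in the included functor/binary-reduce headers;
// the two structs below exist only for documentation and are not referenced
// by the kernels.
template <typename DType>
struct ExampleMul {  // hypothetical BinaryOp: elementwise multiply
  static inline DType Call(DType lhs, DType rhs) { return lhs * rhs; }
  // d(lhs * rhs) / d(lhs) = rhs
  static inline DType BackwardLhs(DType lhs, DType rhs, DType e) { return rhs; }
  // d(lhs * rhs) / d(rhs) = lhs
  static inline DType BackwardRhs(DType lhs, DType rhs, DType e) { return lhs; }
};

template <typename DType>
struct ExampleSumReducer {  // hypothetical Reducer: sum aggregation
  static inline void Call(DType* addr, DType val) { *addr += val; }
  // For a sum reduction d(out)/d(e) = 1, so the accumulated output is not
  // needed when back-propagating through the reducer.
  static inline DType BackwardCall(DType e, DType accum) {
    return static_cast<DType>(1);
  }
};
// With such a pair, ApplyEdge computes grad_e = grad_out * BackwardCall(e, out)
// and grad_lhs/grad_rhs = grad_e * BackwardLhs/BackwardRhs(lhs, rhs, e), i.e.
// the chain rule applied per feature element.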
// Minigun UDF to compute backward binary reduce with broadcasting.
template <int Mode, int NDim, typename Idx, typename DType, typename Functors>
struct BackwardBinaryReduceBcast {
  static inline bool CondEdge(
      Idx src, Idx dst, Idx eid, BackwardBcastGData<NDim, Idx, DType>* gdata) {
    return true;
  }
  static inline void ApplyEdge(
      Idx src, Idx dst, Idx eid, BackwardBcastGData<NDim, Idx, DType>* gdata) {
    Idx lid = Functors::SelectLeft(src, eid, dst);
    Idx rid = Functors::SelectRight(src, eid, dst);
    Idx oid = Functors::SelectOut(src, eid, dst);
    if (gdata->lhs_mapping) {
      lid = Functors::GetId(lid, gdata->lhs_mapping);
    }
    if (gdata->rhs_mapping) {
      rid = Functors::GetId(rid, gdata->rhs_mapping);
    }
    if (gdata->out_mapping) {
      oid = Functors::GetId(oid, gdata->out_mapping);
    }
    DType* lhsoff = gdata->lhs_data + lid * gdata->lhs_len;
    DType* rhsoff = gdata->rhs_data + rid * gdata->rhs_len;
    DType* outoff = gdata->out_data + oid * gdata->out_len;
    DType* gradlhsoff = gdata->grad_lhs_data + lid * gdata->out_len;
    DType* gradrhsoff = gdata->grad_rhs_data + rid * gdata->out_len;
    DType* gradoutoff = gdata->grad_out_data + oid * gdata->out_len;
    int64_t tmp[NDim];  // store unraveled idx.
    for (int64_t tx = 0; tx < gdata->out_len; ++tx) {
      // Map the flat output index to lhs/rhs offsets under broadcasting.
      Unravel(tx, gdata->ndim, gdata->out_shape, gdata->out_stride, tmp);
      DType lhs = Functors::Read(lhsoff +
          Ravel(tmp, gdata->ndim, gdata->lhs_shape, gdata->lhs_stride));
      DType rhs = Functors::Read(rhsoff +
          Ravel(tmp, gdata->ndim, gdata->rhs_shape, gdata->rhs_stride));
      DType out = Functors::Read(outoff + tx);
      DType grad_out = Functors::Read(gradoutoff + tx);
      DType e = Functors::Op(lhs, rhs);
      DType grad_e = grad_out * Functors::BackwardWrite(e, out);
      if (Mode == binary_op::kGradLhs || Mode == binary_op::kGradBoth) {
        DType grad_lhs = grad_e * Functors::BackwardOpLhs(lhs, rhs, e);
#pragma omp atomic
        gradlhsoff[tx] += grad_lhs;
      }
      if (Mode == binary_op::kGradRhs || Mode == binary_op::kGradBoth) {
        DType grad_rhs = grad_e * Functors::BackwardOpRhs(lhs, rhs, e);
#pragma omp atomic
        gradrhsoff[tx] += grad_rhs;
      }
    }
  }
};

// Auxiliary template used in the UDFs.
template <typename Idx, typename DType,
          typename LeftSelector, typename RightSelector,
          typename BinaryOp, typename Reducer>
struct BackwardFunctorsTempl {
  static inline Idx SelectOut(
      Idx src, Idx edge, Idx dst) {
    typedef typename OutSelector<Reducer>::Type OutTarget;
    return SwitchSrcDst<OutTarget>::Type::Call(src, edge, dst);
  }
  static inline Idx SelectLeft(
      Idx src, Idx edge, Idx dst) {
    return LeftSelector::Call(src, edge, dst);
  }
  static inline Idx SelectRight(
      Idx src, Idx edge, Idx dst) {
    return RightSelector::Call(src, edge, dst);
  }
  static inline DType Op(DType lhs, DType rhs) {
    return BinaryOp::Call(lhs, rhs);
  }
  static inline DType Read(DType* addr) {
    return *addr;
  }
  static inline void Write(DType* addr, DType val) {
    Reducer::Call(addr, val);
  }
  static inline Idx GetId(Idx id, Idx* id_map) {
    return *(id_map + id);
  }
  static inline DType BackwardWrite(DType val, DType accum) {
    return Reducer::BackwardCall(val, accum);
  }
  static inline DType BackwardOpLhs(DType lhs, DType rhs, DType out) {
    return BinaryOp::BackwardLhs(lhs, rhs, out);
  }
  static inline DType BackwardOpRhs(DType lhs, DType rhs, DType out) {
    return BinaryOp::BackwardRhs(lhs, rhs, out);
  }
};

typedef minigun::advance::Config AdvanceConfig;

}  // namespace cpu

// Template implementation of the BackwardBinaryReduce operator.
template <int XPU, int Mode, typename Idx, typename DType,
          typename LeftSelector, typename RightSelector,
          typename BinaryOp, typename Reducer>
void CallBackwardBinaryReduce(
    const minigun::advance::RuntimeConfig& rtcfg,
    const CSRWrapper& graph,
    BackwardGData<Idx, DType>* gdata) {
  // For backward computation, we use the reverse (in-edge) CSR and switch dst
  // and src. This benefits the most common src_op_edge and copy_src cases,
  // because the gradients of src are then aggregated into the destination
  // buffer, reducing contention on the atomic add.
  auto incsr = graph.GetInCSRMatrix();
  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr.indptr, incsr.indices);
  typedef cpu::BackwardFunctorsTempl<Idx, DType,
          typename SwitchSrcDst<LeftSelector>::Type,
          typename SwitchSrcDst<RightSelector>::Type,
          BinaryOp, Reducer> Functors;
  typedef cpu::BackwardBinaryReduce<Mode, Idx, DType, Functors> UDF;
  // If the user-given mapping is none and the target is edge data, we need to
  // replace the mapping with the edge ids in the CSR graph so that the edge
  // data is correctly read/written.
  if (LeftSelector::target == binary_op::kEdge
      && gdata->lhs_mapping == nullptr) {
    gdata->lhs_mapping = static_cast<Idx*>(incsr.data->data);
  }
  if (RightSelector::target == binary_op::kEdge
      && gdata->rhs_mapping == nullptr) {
    gdata->rhs_mapping = static_cast<Idx*>(incsr.data->data);
  }
  if (OutSelector<Reducer>::Type::target == binary_op::kEdge
      && gdata->out_mapping == nullptr) {
    gdata->out_mapping = static_cast<Idx*>(incsr.data->data);
  }
  // TODO(minjie): allocator
  minigun::advance::Advance<XPU, Idx, cpu::AdvanceConfig,
    BackwardGData<Idx, DType>, UDF>(
        rtcfg, csr, gdata, minigun::IntArray1D());
}
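// Example of how the operator above is typically instantiated (illustrative
// only; the selector, op, and reducer names are placeholders for the types
// declared in the common binary-reduce headers, and the device/index macros
// are supplied by the .cc files that include this header):
//
//   BackwardGData<int64_t, float> bgdata;  // filled in by the caller
//   CallBackwardBinaryReduce<kDLCPU, binary_op::kGradLhs, int64_t, float,
//                            LhsSelector, RhsSelector,
//                            SomeBinaryOp<float>, SomeReducer<kDLCPU, float>>(
//       rtcfg, graph, &bgdata);
//
// The GEN_BACKWARD_DEFINE macro below stamps out such explicit instantiations
// for each supported (mode, dtype, target, op) combination, with XPU, IDX, and
// REDUCER defined by the including translation unit.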
// The following macro is used to generate explicit specializations of the
// template operator.
#define GEN_BACKWARD_DEFINE(mode, dtype, lhs_tgt, rhs_tgt, op)      \
  template void CallBackwardBinaryReduce<XPU,                       \
                    mode, IDX, dtype,                               \
                    lhs_tgt, rhs_tgt,                               \
                    op<dtype>, REDUCER<XPU, dtype>>(                \
      const minigun::advance::RuntimeConfig& rtcfg,                 \
      const CSRWrapper& graph,                                      \
      BackwardGData<IDX, dtype>* gdata);

// Template implementation of the BackwardBinaryReduce operator with
// broadcasting.
template <int XPU, int Mode, int NDim, typename Idx, typename DType,
          typename LeftSelector, typename RightSelector,
          typename BinaryOp, typename Reducer>
void CallBackwardBinaryReduceBcast(
    const minigun::advance::RuntimeConfig& rtcfg,
    const CSRWrapper& graph,
    BackwardBcastGData<NDim, Idx, DType>* gdata) {
  // For backward computation, we use the reverse (in-edge) CSR and switch dst
  // and src. This benefits the most common src_op_edge and copy_src cases,
  // because the gradients of src are then aggregated into the destination
  // buffer, reducing contention on the atomic add.
  auto incsr = graph.GetInCSRMatrix();
  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr.indptr, incsr.indices);
  typedef cpu::BackwardFunctorsTempl<Idx, DType,
          typename SwitchSrcDst<LeftSelector>::Type,
          typename SwitchSrcDst<RightSelector>::Type,
          BinaryOp, Reducer> Functors;
  typedef cpu::BackwardBinaryReduceBcast<Mode, NDim, Idx, DType, Functors> UDF;
  // If the user-given mapping is none and the target is edge data, we need to
  // replace the mapping with the edge ids in the CSR graph so that the edge
  // data is correctly read/written.
  if (LeftSelector::target == binary_op::kEdge
      && gdata->lhs_mapping == nullptr) {
    gdata->lhs_mapping = static_cast<Idx*>(incsr.data->data);
  }
  if (RightSelector::target == binary_op::kEdge
      && gdata->rhs_mapping == nullptr) {
    gdata->rhs_mapping = static_cast<Idx*>(incsr.data->data);
  }
  if (OutSelector<Reducer>::Type::target == binary_op::kEdge
      && gdata->out_mapping == nullptr) {
    gdata->out_mapping = static_cast<Idx*>(incsr.data->data);
  }
  // TODO(minjie): allocator
  minigun::advance::Advance<XPU, Idx, cpu::AdvanceConfig,
    BackwardBcastGData<NDim, Idx, DType>, UDF>(
        rtcfg, csr, gdata, minigun::IntArray1D());
}

// The following macro is used to generate explicit specializations of the
// template operator.
#define GEN_BACKWARD_BCAST_DEFINE(mode, ndim, dtype, lhs_tgt, rhs_tgt, op) \
  template void CallBackwardBinaryReduceBcast<XPU,                         \
                    mode, ndim, IDX, dtype,                                \
                    lhs_tgt, rhs_tgt,                                      \
                    op<dtype>, REDUCER<XPU, dtype>>(                       \
      const minigun::advance::RuntimeConfig& rtcfg,                        \
      const CSRWrapper& graph,                                             \
      BackwardBcastGData<ndim, IDX, dtype>* gdata);

}  // namespace kernel
}  // namespace dgl

#endif  // DGL_KERNEL_CPU_BACKWARD_BINARY_REDUCE_IMPL_H_