/*!
 *  Copyright (c) 2020 by Contributors
 * \file kernel/cpu/segment_reduce.cc
 * \brief Segment reduce C APIs and definitions.
 *
 * NOTE(review): as this file currently reads, two of the `#include`
 * directives below name no header, every `template` keyword is followed
 * directly by `void` with no parameter list, and the `cpu::` calls carry no
 * template arguments (two of them are followed by a stray `>`). That looks
 * like text between angle brackets was lost somewhere - confirm against the
 * canonical copy of this file and restore it before building.
 */
#include "./segment_reduce.h"
#include
#include
#include "./spmm_binary_ops.h"

namespace dgl {
namespace aten {

/*!
 * \brief Segment Reduce operator.
 *
 * Picks the CPU kernel from the reducer name in \p op:
 *  - "sum" calls cpu::SegmentSum(feat, offsets, out); \p arg is not used.
 *  - "max" and "min" call cpu::SegmentCmp(feat, offsets, out, arg).
 *  - any other name is reported through LOG(FATAL).
 *
 * Every kernel call is wrapped in SWITCH_BITS(bits, DType, ...). The macro is
 * defined outside this file and `bits` is not declared in the text shown
 * here; presumably `bits` is a template parameter and the macro binds DType
 * to the element type of that width - TODO confirm.
 *
 * \param op Reducer name; "sum", "max" and "min" are the accepted values.
 * \param feat Passed to the kernel as its first argument (presumably the
 *        input features - TODO confirm).
 * \param offsets Passed to the kernel as its second argument (presumably the
 *        segment boundaries - TODO confirm).
 * \param out Passed to the kernel as its third argument (presumably the
 *        array the kernel writes its result to - TODO confirm).
 * \param arg Passed only to cpu::SegmentCmp, as its fourth argument
 *        (presumably the arg-max / arg-min positions - TODO confirm).
 */
template void SegmentReduce(
    const std::string& op,
    NDArray feat,
    NDArray offsets,
    NDArray out,
    NDArray arg) {
  if (op == "sum") {
    SWITCH_BITS(bits, DType, {
      cpu::SegmentSum(feat, offsets, out);
    });
  } else if (op == "max" || op == "min") {
    // NOTE(review): the two branches below read identically in this text.
    // Whatever told max apart from min presumably sat in the template
    // arguments of cpu::SegmentCmp that are missing before the `>` - TODO
    // confirm against the canonical copy.
    if (op == "max") {
      SWITCH_BITS(bits, DType, {
        cpu::SegmentCmp>(
            feat, offsets, out, arg);
      });
    } else {
      SWITCH_BITS(bits, DType, {
        cpu::SegmentCmp>(
            feat, offsets, out, arg);
      });
    }
  } else {
    LOG(FATAL) << "Unsupported reduce function " << op;
  }
}

/*!
 * \brief Scatter Add.
 *
 * Forwards \p feat, \p idx and \p out, in that order, to cpu::ScatterAdd
 * inside SWITCH_BITS(bits, DType, ...). The arithmetic itself lives in
 * cpu::ScatterAdd, which is not defined in this file.
 *
 * \param feat First argument handed to cpu::ScatterAdd.
 * \param idx Second argument handed to cpu::ScatterAdd (presumably the
 *        destination index of each row of \p feat - TODO confirm).
 * \param out Third argument handed to cpu::ScatterAdd.
 */
template void ScatterAdd(NDArray feat,
                         NDArray idx,
                         NDArray out) {
  SWITCH_BITS(bits, DType, {
    cpu::ScatterAdd(feat, idx, out);
  });
}

/*!
 * \brief Update gradients for reduce operator max/min on heterogeneous graph.
 *
 * Forwards all six arguments, in order, to cpu::UpdateGradMinMax_hetero
 * inside SWITCH_BITS(bits, DType, ...). No argument is inspected here; in
 * particular \p op is not validated in this wrapper.
 *
 * \param g Graph handle, passed through.
 * \param op Reducer name, passed through.
 * \param feat Vector passed through (element type is not visible in this
 *        text - see the note at the top of the file).
 * \param idx Vector passed through.
 * \param idx_etype Vector passed through.
 * \param out Pointer to a vector, passed through (presumably the gradients
 *        being updated - TODO confirm).
 */
template void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
                                      const std::string& op,
                                      const std::vector& feat,
                                      const std::vector& idx,
                                      const std::vector& idx_etype,
                                      std::vector* out) {
  SWITCH_BITS(bits, DType, {
    cpu::UpdateGradMinMax_hetero(g, op, feat, idx, idx_etype, out);
  });
}

/*!
\brief Backward function of segment cmp.*/ template void BackwardSegmentCmp( NDArray feat, NDArray arg, NDArray out) { SWITCH_BITS(bits, DType, { cpu::BackwardSegmentCmp(feat, arg, out); }); } template void SegmentReduce( const std::string &op, NDArray feat, NDArray offsets, NDArray out, NDArray arg); template void SegmentReduce( const std::string &op, NDArray feat, NDArray offsets, NDArray out, NDArray arg); template void SegmentReduce( const std::string &op, NDArray feat, NDArray offsets, NDArray out, NDArray arg); template void SegmentReduce( const std::string &op, NDArray feat, NDArray offsets, NDArray out, NDArray arg); template void SegmentReduce( const std::string &op, NDArray feat, NDArray offsets, NDArray out, NDArray arg); template void SegmentReduce( const std::string &op, NDArray feat, NDArray offsets, NDArray out, NDArray arg); template void ScatterAdd( NDArray feat, NDArray idx, NDArray out); template void ScatterAdd( NDArray feat, NDArray idx, NDArray out); template void ScatterAdd( NDArray feat, NDArray idx, NDArray out); template void ScatterAdd( NDArray feat, NDArray idx, NDArray out); template void ScatterAdd( NDArray feat, NDArray idx, NDArray out); template void ScatterAdd( NDArray feat, NDArray arg, NDArray out); template void UpdateGradMinMax_hetero( const HeteroGraphPtr& g, const std::string& op, const std::vector& feat, const std::vector& idx, const std::vector& idx_etype, std::vector* out); template void UpdateGradMinMax_hetero( const HeteroGraphPtr& g, const std::string& op, const std::vector& feat, const std::vector& idx, const std::vector& idx_etype, std::vector* out); template void UpdateGradMinMax_hetero( const HeteroGraphPtr& g, const std::string& op, const std::vector& feat, const std::vector& idx, const std::vector& idx_etype, std::vector* out); template void UpdateGradMinMax_hetero( const HeteroGraphPtr& g, const std::string& op, const std::vector& feat, const std::vector& idx, const std::vector& idx_etype, std::vector* out); 
template void UpdateGradMinMax_hetero( const HeteroGraphPtr& g, const std::string& op, const std::vector& feat, const std::vector& idx, const std::vector& idx_etype, std::vector* out); template void UpdateGradMinMax_hetero( const HeteroGraphPtr& g, const std::string& op, const std::vector& feat, const std::vector& idx, const std::vector& idx_etype, std::vector* out); template void BackwardSegmentCmp( NDArray feat, NDArray arg, NDArray out); template void BackwardSegmentCmp( NDArray feat, NDArray arg, NDArray out); template void BackwardSegmentCmp( NDArray feat, NDArray arg, NDArray out); template void BackwardSegmentCmp( NDArray feat, NDArray arg, NDArray out); template void BackwardSegmentCmp( NDArray feat, NDArray arg, NDArray out); template void BackwardSegmentCmp( NDArray feat, NDArray arg, NDArray out); } // namespace aten } // namespace dgl