// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/segment_reduce.cu
 * @brief Segment reduce C APIs and definitions.
 */
#include <dgl/array.h>
#include <dgl/base_heterograph.h>

#include "functor.cuh"
#include "segment_reduce.cuh"
#include "utils.h"

namespace dgl {

using namespace cuda;

namespace aten {

/**
 * @brief Segment-wise reduction of @p feat, dispatched by operator name.
 *
 * Thin host-side wrapper: selects the reducer functor from functor.cuh
 * (cuda::reduce::Sum / Max / Min) according to @p op and forwards all
 * arrays unchanged to cuda::SegmentReduce.
 *
 * @param op      Reduction name; one of "sum", "max" or "min".
 * @param feat    Input features to reduce.
 * @param offsets Segment boundaries (presumably a CSR-style offset array --
 *                NOTE(review): confirm against segment_reduce.cuh).
 * @param out     Output array written by the device routine.
 * @param arg     Forwarded to the device routine; presumably receives the
 *                argmax/argmin index for "max"/"min" -- confirm in
 *                segment_reduce.cuh.
 * @note Any other value of @p op aborts via LOG(FATAL).
 */
template <DGLDeviceType XPU, typename IdType, typename DType>
void SegmentReduce(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg) {
  if (op == "sum") {
    cuda::SegmentReduce<IdType, DType, cuda::reduce::Sum<IdType, DType>>(
        feat, offsets, out, arg);
  } else if (op == "max") {
    cuda::SegmentReduce<IdType, DType, cuda::reduce::Max<IdType, DType>>(
        feat, offsets, out, arg);
  } else if (op == "min") {
    cuda::SegmentReduce<IdType, DType, cuda::reduce::Min<IdType, DType>>(
        feat, offsets, out, arg);
  } else {
    // Include the rejected name so a caller-side typo is immediately visible.
    LOG(FATAL) << "Not implemented: unsupported segment reduce op '" << op
               << "'";
  }
}

/**
 * @brief Forwards to cuda::ScatterAdd<IdType, DType>.
 *
 * @param feat Source features.
 * @param idx  Index array consumed by the device routine (semantics defined
 *             in segment_reduce.cuh).
 * @param out  Output array written by the device routine.
 */
template <DGLDeviceType XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
  cuda::ScatterAdd<IdType, DType>(feat, idx, out);
}

/**
 * @brief Forwards to cuda::UpdateGradMinMax_hetero<IdType, DType>.
 *
 * Heterograph variant of the min/max gradient update; every argument is
 * passed through unchanged.
 *
 * @param g         Heterogeneous graph.
 * @param op        Reduction name forwarded to the device routine.
 * @param feat      Per-type feature arrays.
 * @param idx       Per-type index arrays.
 * @param idx_etype Per-type edge-type index arrays.
 * @param out       Per-type output arrays, written by the device routine.
 */
template <DGLDeviceType XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
  cuda::UpdateGradMinMax_hetero<IdType, DType>(
      g, op, feat, idx, idx_etype, out);
}

/**
 * @brief Forwards to cuda::BackwardSegmentCmp<IdType, DType>.
 *
 * @param feat Incoming gradient / features.
 * @param arg  Index array, presumably the one filled by a "max"/"min"
 *             SegmentReduce -- confirm in segment_reduce.cuh.
 * @param out  Output array written by the device routine.
 */
template <DGLDeviceType XPU, typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
  cuda::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
}

// ---------------------------------------------------------------------------
// Explicit instantiations: 32- and 64-bit index types crossed with half,
// bfloat16 (only when BF16_ENABLED), float and double feature types.
// NOTE(review): the bfloat16 spelling below is hipify's mapping of
// __nv_bfloat16; confirm it matches the dtype used by the declarations.
// ---------------------------------------------------------------------------

template void SegmentReduce<kDGLCUDA, int32_t, __half>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, __half>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
#if BF16_ENABLED
template void SegmentReduce<kDGLCUDA, int32_t, __hip_bfloat16>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, __hip_bfloat16>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
#endif  // BF16_ENABLED
template void SegmentReduce<kDGLCUDA, int32_t, float>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, float>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCUDA, int32_t, double>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, double>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);

template void ScatterAdd<kDGLCUDA, int32_t, __half>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, __half>(
    NDArray feat, NDArray idx, NDArray out);
#if BF16_ENABLED
template void ScatterAdd<kDGLCUDA, int32_t, __hip_bfloat16>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, __hip_bfloat16>(
    NDArray feat, NDArray idx, NDArray out);
#endif  // BF16_ENABLED
template void ScatterAdd<kDGLCUDA, int32_t, float>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, float>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCUDA, int32_t, double>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, double>(
    NDArray feat, NDArray idx, NDArray out);

template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, __half>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, __half>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
#if BF16_ENABLED
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, __hip_bfloat16>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, __hip_bfloat16>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
#endif  // BF16_ENABLED
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, float>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, float>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, double>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, double>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);

template void BackwardSegmentCmp<kDGLCUDA, int32_t, __half>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, __half>(
    NDArray feat, NDArray arg, NDArray out);
#if BF16_ENABLED
template void BackwardSegmentCmp<kDGLCUDA, int32_t, __hip_bfloat16>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, __hip_bfloat16>(
    NDArray feat, NDArray arg, NDArray out);
#endif  // BF16_ENABLED
template void BackwardSegmentCmp<kDGLCUDA, int32_t, float>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, float>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int32_t, double>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, double>(
    NDArray feat, NDArray arg, NDArray out);

}  // namespace aten
}  // namespace dgl