/*! * Copyright (c) 2020 by Contributors * \file array/kernel_decl.h * \brief Sparse matrix format-specific operator declarations. */ #ifndef DGL_ARRAY_KERNEL_DECL_H_ #define DGL_ARRAY_KERNEL_DECL_H_ #include #include #include #include #include #include namespace dgl { namespace aten { /*! * \brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format. */ template void SpMMCsr(const std::string& op, const std::string& reduce, const BcastOff& bcast, const aten::CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out, std::vector out_aux); /*! * \brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format with heterograph support. */ template void SpMMCsrHetero(const std::string& op, const std::string& reduce, const BcastOff& bcast, const std::vector& csr, const std::vector& ufeat, const std::vector& efeat, std::vector* out, std::vector>* out_aux, const std::vector& ufeat_eid, const std::vector& out_eid); /*! * \brief Generalized Sparse Matrix Dense Matrix Multiplication on Coo format. */ template void SpMMCoo(const std::string& op, const std::string& reduce, const BcastOff& bcast, const aten::COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out, std::vector out_aux); /*! * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format. */ template void SDDMMCsr(const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); /*! * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format with heterograph support. */ template void SDDMMCsrHetero(const std::string& op, const BcastOff& bcast, const std::vector& vec_csr, const std::vector& vec_lhs, const std::vector& vec_rhs, std::vector vec_out, int lhs_target, int rhs_target, const std::vector& ufeat_eid, const std::vector& out_eid); /*! * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format. */ template void SDDMMCoo(const std::string& op, const BcastOff& bcast, const aten::COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); /*! * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format with heterograph support. */ template void SDDMMCooHetero(const std::string& op, const BcastOff& bcast, const std::vector& vec_coo, const std::vector& vec_lhs, const std::vector& vec_rhs, std::vector vec_out, int lhs_target, int rhs_target, const std::vector& lhs_eid, const std::vector& rhs_eid); /*! * \brief Generalized Dense Matrix-Matrix Multiplication according to relation types. */ template void GatherMM(const NDArray A, const NDArray B, NDArray out, const NDArray idx_a, const NDArray idx_b); /*! * \brief Generalized Dense Matrix-Matrix Multiplication according to relation types. */ template void GatherMMScatter(const NDArray A, const NDArray B, NDArray out, const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); /*! * \brief Generalized segmented dense Matrix-Matrix Multiplication. */ template void SegmentMM(const NDArray A, const NDArray B, NDArray out, const NDArray seglen_A, bool a_trans, bool b_trans); template void SegmentMMBackwardB(const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); /*! * \brief Segment reduce. */ template void SegmentReduce(const std::string& op, NDArray feat, NDArray offsets, NDArray out, NDArray arg); /*! * \brief Scatter Add on first dimension. */ template void ScatterAdd(NDArray feat, NDArray idx, NDArray out); /*! * \brief Update gradients for reduce operator max and min on first dimension. */ template void UpdateGradMinMax_hetero(const HeteroGraphPtr& g, const std::string& op, const std::vector& feat, const std::vector& idx, const std::vector& idx_etype, std::vector* out); /*! * \brief Backward function of segment cmp. */ template void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out); /*! * \brief Sparse-sparse matrix multiplication * * \param A The left operand. * \param A_weights The weights of matrix as a 1D tensor. * \param B The right operand. * \param B_weights The weights of matrix as a 1D tensor. * * \note GPU implementation will cast the indices to 32 bit. * \note The zero entries in the result are not removed. * \note The CSR matrix should not have duplicate entries. */ template std::pair CSRMM( const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B, NDArray B_weights); /*! * \brief Sparse-sparse matrix summation. * * \param A The sparse matrices with the same size. * \param A_weights The weights of each sparse matrix as a 1D tensor. * * \note GPU implementation will cast the indices to 32 bit. * \note The zero entries in the result are not removed. * \note The CSR matrix should not have duplicate entries. */ template std::pair CSRSum( const std::vector& A, const std::vector& A_weights); /*! * \brief Edge_softmax_csr forward function on Csr format. */ template void Edge_softmax_csr_forward(const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out); /*! * \brief Edge_softmax_csr backward function on Csr format. */ template void Edge_softmax_csr_backward(const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out); } // namespace aten } // namespace dgl #endif // DGL_ARRAY_KERNEL_DECL_H_