/**
 * Copyright (c) 2020 by Contributors
 * @file array/kernel_decl.h
 * @brief Sparse matrix format-specific operator declarations.
 *
 * Declares, in namespace dgl::aten, templated kernels for:
 *  - generalized SpMM and SDDMM on CSR and COO formats, plus heterograph
 *    variants that take per-relation vectors of inputs;
 *  - gather / segment dense matrix-matrix multiplication;
 *  - segment reduction, scatter-add and min/max gradient helpers;
 *  - sparse-sparse CSR multiplication and summation;
 *  - CSR edge softmax (forward and backward).
 * Only declarations appear in this header; it provides no definitions.
 *
 * NOTE(review): in this copy of the file every `<...>` span appears to have
 * been stripped. The #include directives have no targets, each `template`
 * keyword has no parameter list, and `std::vector` / `std::pair` have lost
 * their template arguments (note the stray `>` left in SpMMCsrHetero's
 * `out_aux` type). The header cannot compile as-is; restore those spans
 * from version control. The comments below do not assert the missing
 * element types except where explicitly marked as assumptions.
 */
#ifndef DGL_ARRAY_KERNEL_DECL_H_
#define DGL_ARRAY_KERNEL_DECL_H_

#include
#include
#include
#include
#include
#include

namespace dgl {
namespace aten {

// Conventions shared by the declarations below. These are inferred from
// parameter names and types only; confirm against the implementations.
//  - `op` / `reduce` are operator names selected at runtime by string.
//  - `bcast` (BcastOff) carries the broadcasting metadata relating the
//    shapes of the two feature operands.
//  - Outputs such as `out` are NDArrays passed by value. Presumably the
//    caller preallocates them and the kernel fills them in place through
//    the shared handle.
//  - In the heterograph variants, `*_eid` vectors presumably map each
//    relation to the input/output entry it uses.

/**
 * @brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format.
 *
 * @param op Name of the binary operator that combines `ufeat` and `efeat`
 *        into messages.
 * @param reduce Name of the reduction applied to the messages (presumably
 *        values such as "sum", "max", "min"; confirm).
 * @param bcast Broadcasting metadata for `ufeat` / `efeat`.
 * @param csr The sparse matrix in CSR format.
 * @param ufeat Dense operand, presumably node features; confirm.
 * @param efeat Dense operand, presumably per-nonzero (edge) features; confirm.
 * @param out Output array.
 * @param out_aux Auxiliary outputs, passed by value. Presumably these
 *        receive arg indices for max/min reductions; confirm.
 */
template void SpMMCsr(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const aten::CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector out_aux);

/**
 * @brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format
 * with heterograph support.
 *
 * Heterograph counterpart of SpMMCsr. `csr`, `ufeat` and `efeat` are
 * vectors (element types stripped in this copy; presumably one CSR matrix
 * per relation). Unlike SpMMCsr, the outputs `out` and `out_aux` are passed
 * by pointer.
 *
 * @param ufeat_eid Presumably maps each relation to the `ufeat` entry it
 *        reads; confirm.
 * @param out_eid Presumably maps each relation to the `out` entry it
 *        writes; confirm.
 */
template void SpMMCsrHetero(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector& csr, const std::vector& ufeat,
    const std::vector& efeat, std::vector* out, std::vector>* out_aux,
    const std::vector& ufeat_eid, const std::vector& out_eid);

/**
 * @brief Generalized Sparse Matrix Dense Matrix Multiplication on Coo format.
 *
 * Takes the same parameters as SpMMCsr, except that the sparse matrix is
 * given in COO format (`coo`).
 */
template void SpMMCoo(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const aten::COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector out_aux);

/**
 * @brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format.
 *
 * @param op Name of the binary operator applied to `lhs` and `rhs`.
 * @param bcast Broadcasting metadata for `lhs` / `rhs`.
 * @param csr The sparse matrix in CSR format that defines the sampled
 *        positions.
 * @param lhs Left dense operand.
 * @param rhs Right dense operand.
 * @param out Output array (presumably one row per nonzero/edge; confirm).
 * @param lhs_target Integer code selecting how `lhs` is indexed
 *        (presumably source node, edge, or destination node; confirm).
 * @param rhs_target Same as `lhs_target`, for `rhs`.
 */
template void SDDMMCsr(
    const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);

/**
 * @brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format
 * with heterograph support.
 *
 * Heterograph counterpart of SDDMMCsr, taking vectors of matrices and
 * operands. Note that `vec_out` is taken by value (not by pointer as in
 * SpMMCsrHetero). Results therefore presumably flow through the array
 * handles it contains; confirm.
 *
 * @param ufeat_eid Presumably maps each relation to the input entry it
 *        reads; confirm.
 * @param out_eid Presumably maps each relation to the `vec_out` entry it
 *        writes; confirm.
 */
template void SDDMMCsrHetero(
    const std::string& op, const BcastOff& bcast, const std::vector& vec_csr,
    const std::vector& vec_lhs, const std::vector& vec_rhs,
    std::vector vec_out, int lhs_target, int rhs_target,
    const std::vector& ufeat_eid, const std::vector& out_eid);

/**
 * @brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format.
 *
 * Takes the same parameters as SDDMMCsr, except that the sparse matrix is
 * given in COO format (`coo`).
 */
template void SDDMMCoo(
    const std::string& op, const BcastOff& bcast, const aten::COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);

/**
 * @brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format
 * with heterograph support.
 *
 * Heterograph counterpart of SDDMMCoo. As in SDDMMCsrHetero, `vec_out` is
 * taken by value.
 *
 * NOTE(review): the trailing index vectors here are named `lhs_eid` /
 * `rhs_eid`, while SDDMMCsrHetero names the corresponding parameters
 * `ufeat_eid` / `out_eid`. Confirm which meaning is correct. Parameter
 * names in a declaration do not affect callers.
 */
template void SDDMMCooHetero(
    const std::string& op, const BcastOff& bcast, const std::vector& vec_coo,
    const std::vector& vec_lhs, const std::vector& vec_rhs,
    std::vector vec_out, int lhs_target, int rhs_target,
    const std::vector& lhs_eid, const std::vector& rhs_eid);

/**
 * @brief Generalized Dense Matrix-Matrix Multiplication according to relation
 * types.
 *
 * @param A Left dense operand.
 * @param B Right dense operand.
 * @param out Output array.
 * @param idx_a Index array for `A` (presumably selects the rows of `A`
 *        used; confirm).
 * @param idx_b Index array for `B` (presumably selects the relation-specific
 *        matrix of `B` for each row; confirm).
 */
template void GatherMM(
    const NDArray A, const NDArray B, NDArray out, const NDArray idx_a,
    const NDArray idx_b);

/**
 * @brief Generalized Dense Matrix-Matrix Multiplication according to relation
 * types, with an extra output index.
 *
 * Same parameters as GatherMM plus `idx_c`. Going by the name, `idx_c`
 * presumably gives the row of `out` each result is scattered into; confirm.
 */
template void GatherMMScatter(
    const NDArray A, const NDArray B, NDArray out, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);

/**
 * @brief Generalized segmented dense Matrix-Matrix Multiplication.
 *
 * @param A Left dense operand.
 * @param B Right dense operand.
 * @param out Output array.
 * @param seglen_A Segment lengths for `A` (presumably consecutive row
 *        segments, each multiplied by its own matrix from `B`; confirm).
 * @param a_trans Whether `A` is treated as transposed (by name).
 * @param b_trans Whether `B` is treated as transposed (by name).
 */
template void SegmentMM(
    const NDArray A, const NDArray B, NDArray out, const NDArray seglen_A,
    bool a_trans, bool b_trans);

/**
 * @brief Backward helper for SegmentMM with respect to `B` (judging by the
 * name).
 *
 * Presumably computes the gradient `dB` from the forward input `A` and the
 * output gradient `dC`, using the segment lengths `seglen`; confirm against
 * the implementation.
 */
template void SegmentMMBackwardB(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);

/**
 * @brief Segment reduce.
 *
 * @param op Name of the reduction.
 * @param feat Input array.
 * @param offsets Presumably the segment boundaries; confirm.
 * @param out Output array.
 * @param arg Presumably receives arg indices for min/max reductions; confirm.
 */
template void SegmentReduce(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);

/**
 * @brief Scatter Add on first dimension.
 *
 * @param feat Rows to add.
 * @param idx Presumably the destination row in `out` for each row of `feat`;
 *        confirm.
 * @param out Accumulation target.
 */
template void ScatterAdd(NDArray feat, NDArray idx, NDArray out);

/**
 * @brief Update gradients for reduce operator max and min on first dimension.
 *
 * Heterograph variant: takes the graph `g` and vectors of per-relation
 * arrays, and writes through the pointer `out`.
 *
 * @param op Name of the reduction (max or min, per the brief).
 * @param idx Presumably the arg indices recorded by the forward reduction;
 *        confirm.
 * @param idx_etype Presumably the relation type of each recorded arg;
 *        confirm.
 */
template void UpdateGradMinMax_hetero(
    const HeteroGraphPtr& g, const std::string& op, const std::vector& feat,
    const std::vector& idx, const std::vector& idx_etype, std::vector* out);

/**
 * @brief Backward function of segment cmp.
 *
 * Presumably routes the gradients in `feat` to the positions recorded in
 * `arg` (as produced by a min/max SegmentReduce), writing into `out`;
 * confirm.
 */
template void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out);

/**
 * @brief Sparse-sparse matrix multiplication
 *
 * @param A The left operand.
 * @param A_weights The weights of matrix as a 1D tensor.
 * @param B The right operand.
 * @param B_weights The weights of matrix as a 1D tensor.
 * @return A pair. Its element types are stripped in this copy; presumably
 *         the result CSR matrix and its weights. Confirm.
 *
 * @note GPU implementation will cast the indices to 32 bit.
 * @note The zero entries in the result are not removed.
 * @note The CSR matrix should not have duplicate entries.
 */
template std::pair CSRMM(
    const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B,
    NDArray B_weights);

/**
 * @brief Sparse-sparse matrix summation.
 *
 * @param A The sparse matrices to sum; all must have the same size.
 * @param A_weights The weights of each sparse matrix as a 1D tensor, one
 *        entry of the vector per matrix in `A`.
 * @return A pair. Its element types are stripped in this copy; presumably
 *         the summed CSR matrix and its weights. Confirm.
 *
 * @note GPU implementation will cast the indices to 32 bit.
 * @note The zero entries in the result are not removed.
 * @note The CSR matrix should not have duplicate entries.
 */
template std::pair CSRSum(
    const std::vector& A, const std::vector& A_weights);

/**
 * @brief Edge_softmax_csr forward function on Csr format.
 *
 * Presumably normalizes the per-edge scores in `efeat` over the edges of
 * each CSR row and writes the result to `out`. The role of `ufeat` is not
 * visible from this declaration. Confirm against the implementation.
 */
template void Edge_softmax_csr_forward(
    const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);

/**
 * @brief Edge_softmax_csr backward function on Csr format.
 *
 * Has the same parameter list as Edge_softmax_csr_forward.
 * NOTE(review): the names `ufeat` / `efeat` / `out` are reused from the
 * forward. Here they presumably carry the forward output, the incoming
 * gradient, and the input gradient; confirm.
 */
template void Edge_softmax_csr_backward(
    const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);

}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_KERNEL_DECL_H_