/** * Copyright (c) 2020 by Contributors * @file array/cuda/sddmm.cu * @brief SDDMM C APIs and definitions. */ #include #include "./sddmm.cuh" #include "./functor.cuh" namespace dgl { namespace aten { /** * @brief CUDA implementation of g-SDDMM on Csr format. */ template void SDDMMCsr(const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) { SWITCH_OP(op, Op, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { cuda::SDDMMCsr(bcast, csr, lhs, rhs, out); }); }); } /** * @brief CUDA implementation of g-SDDMM on Coo format. */ template void SDDMMCoo(const std::string& op, const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) { SWITCH_OP(op, Op, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { cuda::SDDMMCoo(bcast, coo, lhs, rhs, out); }); }); } template void SDDMMCsr( const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); template void SDDMMCsr( const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); #if BF16_ENABLED template void SDDMMCsr( const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); template void SDDMMCsr( const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); #endif // BF16_ENABLED template void SDDMMCsr( const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); template void SDDMMCsr( const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); template void SDDMMCsr( const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); template void SDDMMCsr( const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); template void SDDMMCoo( const std::string& op, const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); template void SDDMMCoo( const std::string& op, const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); #if BF16_ENABLED template void SDDMMCoo( const std::string& op, const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); template void SDDMMCoo( const std::string& op, const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); #endif // BF16_ENABLED template void SDDMMCoo( const std::string& op, const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); template void SDDMMCoo( const std::string& op, const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); template void SDDMMCoo( const std::string& op, const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); template void SDDMMCoo( const std::string& op, const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target); } // namespace aten } // namespace dgl