/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/sddmm.cu
 * \brief SDDMM C APIs and definitions.
 */
#include <dgl/array.h>

#include <string>
#include <vector>

#include "./sddmm.cuh"

namespace dgl {
namespace aten {

/*!
 * \brief CUDA implementation of g-SDDMM on a heterograph whose relations are
 *        stored in CSR format.
 *
 * The element type, binary operator and operand targets are resolved once via
 * SWITCH_BITS / SWITCH_OP / SWITCH_TARGET. Then cuda::SDDMMCsr is called once
 * per relation type, one relation after another (no per-relation parallelism
 * at this level).
 *
 * \tparam XPU    Device type; this file only instantiates kDLGPU.
 * \tparam IdType Index type of the CSR matrices; instantiated for int32_t and
 *                int64_t below.
 * \tparam bits   Feature bit width dispatched by SWITCH_BITS into DType;
 *                instantiated for 16, 32 and 64 (presumably half / float /
 *                double -- confirm against SWITCH_BITS).
 * \param op         Name of the binary operator, dispatched by SWITCH_OP.
 * \param bcast      Broadcast offsets, forwarded unchanged to every call.
 * \param vec_csr    CSR matrix of each relation type, indexed by etype.
 * \param vec_lhs    Candidate lhs operands; relation etype uses
 *                   vec_lhs[lhs_eid[etype]].
 * \param vec_rhs    Candidate rhs operands; relation etype uses
 *                   vec_rhs[rhs_eid[etype]].
 * \param vec_out    Output array of each relation type, indexed by etype.
 *                   The vector is taken by value; NOTE(review): results reach
 *                   the caller only because NDArray copies presumably share
 *                   storage -- confirm.
 * \param lhs_target Operand target selector for lhs, dispatched by
 *                   SWITCH_TARGET.
 * \param rhs_target Operand target selector for rhs, dispatched by
 *                   SWITCH_TARGET.
 * \param lhs_eid    Per-relation index into vec_lhs. Its length defines the
 *                   number of relation types processed.
 * \param rhs_eid    Per-relation index into vec_rhs.
 */
template <int XPU, typename IdType, int bits>
void SDDMMCsrHetero(const std::string& op,
                    const BcastOff& bcast,
                    const std::vector<CSRMatrix>& vec_csr,
                    const std::vector<NDArray>& vec_lhs,
                    const std::vector<NDArray>& vec_rhs,
                    std::vector<NDArray> vec_out,
                    int lhs_target,
                    int rhs_target,
                    const std::vector<dgl_type_t>& lhs_eid,
                    const std::vector<dgl_type_t>& rhs_eid) {
  // The loop is bounded by lhs_eid.size() but the same etype also indexes
  // rhs_eid, vec_csr and vec_out. Fail loudly on a shorter vector instead of
  // reading past its end (only inputs that were already undefined behavior
  // are rejected).
  const size_t num_etypes = lhs_eid.size();
  CHECK_GE(rhs_eid.size(), num_etypes)
      << "rhs_eid must have an entry for every relation type in lhs_eid";
  CHECK_GE(vec_csr.size(), num_etypes)
      << "vec_csr must have a CSR matrix for every relation type";
  CHECK_GE(vec_out.size(), num_etypes)
      << "vec_out must have an output array for every relation type";

  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        /* Call SDDMM CUDA kernel for each relation type sequentially */
        for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
          CSRMatrix csr = vec_csr[etype];
          NDArray lhs = vec_lhs[lhs_eid[etype]];
          NDArray rhs = vec_rhs[rhs_eid[etype]];
          NDArray out = vec_out[etype];
          cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
              bcast, csr, lhs, rhs, out);
        }
      });
    });
  });
}

// One macro for every explicit instantiation, so the six copies of the
// signature cannot drift apart (the previous hand-written copies had already
// renamed lhs_eid/rhs_eid to in_eid/out_eid).
#define DGL_INSTANTIATE_SDDMM_CSR_HETERO(IdType, bits)                   \
  template void SDDMMCsrHetero<kDLGPU, IdType, bits>(                    \
      const std::string& op, const BcastOff& bcast,                      \
      const std::vector<CSRMatrix>& vec_csr,                             \
      const std::vector<NDArray>& vec_lhs,                               \
      const std::vector<NDArray>& vec_rhs,                               \
      std::vector<NDArray> vec_out, int lhs_target, int rhs_target,      \
      const std::vector<dgl_type_t>& lhs_eid,                            \
      const std::vector<dgl_type_t>& rhs_eid)

// GPU instantiations: {int32_t, int64_t} index types x {16, 32, 64} bits.
DGL_INSTANTIATE_SDDMM_CSR_HETERO(int32_t, 16);
DGL_INSTANTIATE_SDDMM_CSR_HETERO(int64_t, 16);
DGL_INSTANTIATE_SDDMM_CSR_HETERO(int32_t, 32);
DGL_INSTANTIATE_SDDMM_CSR_HETERO(int64_t, 32);
DGL_INSTANTIATE_SDDMM_CSR_HETERO(int32_t, 64);
DGL_INSTANTIATE_SDDMM_CSR_HETERO(int64_t, 64);

#undef DGL_INSTANTIATE_SDDMM_CSR_HETERO

}  // namespace aten
}  // namespace dgl