/*!
 *  Copyright (c) 2020 by Contributors
 * \file kernel/cpu/gather_mm.cc
 * \brief GatherMM C APIs and definitions.
 *
 * CPU placeholders for SegmentMM, SegmentMMBackwardB, GatherMM and
 * GatherMMScatter. None of them has a CPU implementation; each one only
 * reports a fatal "Unsupported CPU kernel" error via LOG(FATAL).
 */
#include "./gather_mm.h"
// NOTE(review): the header name of the following #include is missing in this
// copy of the file (a bare `#include` is ill-formed) -- TODO restore it.
#include

namespace dgl {
namespace aten {

// NOTE(review): in this copy every `template` keyword below has no template
// parameter list, and the explicit instantiations at the bottom carry no
// template arguments. Both look stripped (angle-bracket contents lost) --
// TODO restore them from the declarations these definitions implement.

/*!
 * \brief Generalized SegmentMM (CPU placeholder).
 *
 * There is no CPU kernel: every argument (A, B, C, seglen_A, a_trans,
 * b_trans) is ignored and a fatal error is reported through LOG(FATAL).
 */
template void SegmentMM(const NDArray A, const NDArray B, NDArray C,
                        const NDArray seglen_A,
                        bool a_trans, bool b_trans) {
  LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
}

/*!
 * \brief Backward of SegmentMM w.r.t. B (CPU placeholder).
 *
 * There is no CPU kernel: every argument (A, dC, dB, seglen) is ignored and
 * a fatal error is reported through LOG(FATAL).
 */
template void SegmentMMBackwardB(const NDArray A, const NDArray dC,
                                 NDArray dB, const NDArray seglen) {
  LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB.";
}

/*!
 * \brief Generalized GatherMM (CPU placeholder).
 *
 * There is no CPU kernel: every argument (A, B, C, idx_a, idx_b) is ignored
 * and a fatal error is reported through LOG(FATAL).
 */
template void GatherMM(const NDArray A, const NDArray B, NDArray C,
                       const NDArray idx_a, const NDArray idx_b) {
  LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
}

/*!
 * \brief Generalized GatherMMScatter (CPU placeholder).
 *
 * There is no CPU kernel: every argument (A, B, C, idx_a, idx_b, idx_c) is
 * ignored and a fatal error is reported through LOG(FATAL). The message names
 * this kernel (it previously said "GatherMM", which misidentified the caller).
 */
template void GatherMMScatter(const NDArray A, const NDArray B, NDArray C,
                              const NDArray idx_a, const NDArray idx_b,
                              const NDArray idx_c) {
  LOG(FATAL) << "Unsupported CPU kernel for GatherMMScatter.";
}

// Explicit instantiations. Four per function; their template arguments are
// missing in this copy (see the NOTE(review) above).

template void GatherMM(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);

template void GatherMMScatter(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);

template void SegmentMM(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);

template void SegmentMMBackwardB(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);

}  // namespace aten
}  // namespace dgl