/**
 *  Copyright (c) 2020 by Contributors
 * @file kernel/cpu/gather_mm.cc
 * @brief GatherMM C APIs and definitions.
 *
 * CPU placeholders for the SegmentMM / GatherMM kernel family. None of these
 * kernels is implemented on CPU: every function below ignores its arguments
 * and reports a fatal error through LOG(FATAL).
 *
 * The explicit instantiations at the bottom emit definitions for the CPU
 * specializations (int32_t / int64_t index types, float / double value types)
 * so they are available at link time.
 */
#include "./gather_mm.h"

#include <dgl/array.h>

namespace dgl {
namespace aten {

/**
 * @brief Generalized SegmentMM (CPU placeholder).
 *
 * Not implemented on CPU: all arguments are ignored and a fatal error is
 * reported via LOG(FATAL).
 */
template <int XPU, typename IdType, typename DType>
void SegmentMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans) {
  LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
}

/**
 * @brief Backward pass of SegmentMM w.r.t. B (CPU placeholder).
 *
 * Not implemented on CPU: all arguments are ignored and a fatal error is
 * reported via LOG(FATAL).
 */
template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
  LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB.";
}

/**
 * @brief Generalized GatherMM (CPU placeholder).
 *
 * Not implemented on CPU: all arguments are ignored and a fatal error is
 * reported via LOG(FATAL).
 */
template <int XPU, typename IdType, typename DType>
void GatherMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b) {
  LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
}

/**
 * @brief Generalized GatherMM_scatter (CPU placeholder).
 *
 * Not implemented on CPU: all arguments are ignored and a fatal error is
 * reported via LOG(FATAL).
 */
template <int XPU, typename IdType, typename DType>
void GatherMMScatter(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c) {
  // Name this kernel explicitly so the failure is not confused with GatherMM.
  LOG(FATAL) << "Unsupported CPU kernel for GatherMMScatter.";
}

// ---------------------------------------------------------------------------
// Explicit instantiations: CPU x {int32_t, int64_t} x {float, double}.
// ---------------------------------------------------------------------------

template void GatherMM<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);

template void GatherMMScatter<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);

template void SegmentMM<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);

template void SegmentMMBackwardB<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, float>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int32_t, double>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, double>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);

}  // namespace aten
}  // namespace dgl