/*!
 *  Copyright (c) 2020 by Contributors
 * \file kernel/cpu/spmm.cc
 * \brief SPMM C APIs and definitions.
 */
#include "./spmm.h"
// NOTE(review): the directive below has no header name in this copy (text
// between angle brackets appears to have been stripped) -- restore it from
// the repository before building.
#include
namespace dgl {
namespace aten {

/*!
 * \brief Generalized SpMM on Csr format.
 *
 * Chooses a CPU kernel from the reducer name and forwards the arguments:
 *  - "sum":           cpu::SpMMSumCsr(bcast, csr, ufeat, efeat, out)
 *  - "max" or "min":  cpu::SpMMCmpCsr(bcast, csr, ufeat, efeat, out,
 *                                     out_aux[0], out_aux[1])
 *  - anything else:   LOG(FATAL) with "Unsupported SpMM reducer: " + reduce
 * Every kernel call is wrapped in SWITCH_OP(op, Op, ...), a macro defined
 * outside this view (presumably it binds `Op` to the operator selected by
 * `op` -- TODO confirm).
 *
 * NOTE(review): `template` is followed directly by `void` although the
 * function has a body, `std::vector` has no element type, and a stray `>`
 * precedes `(` in the SpMMCmpCsr calls. The template parameter list and the
 * kernels' template argument lists appear to be missing from this copy --
 * restore them from the repository.
 *
 * \param op      Operator name handed to SWITCH_OP.
 * \param reduce  Reducer name; "sum", "max" and "min" are handled.
 * \param bcast   Forwarded unchanged to the kernel.
 * \param csr     Sparse matrix in CSR format, forwarded to the kernel.
 * \param ufeat   Forwarded to the kernel (presumably source-node features --
 *                TODO confirm).
 * \param efeat   Forwarded to the kernel (presumably edge features -- TODO
 *                confirm).
 * \param out     Forwarded to the kernel (presumably the result buffer --
 *                TODO confirm).
 * \param out_aux Elements [0] and [1] are forwarded to the kernel for "max"
 *                and "min"; they are indexed without a size check, so the
 *                caller must supply at least two entries for those reducers.
 *                Not read for "sum".
 */
template void SpMMCsr(const std::string& op, const std::string& reduce,
                      const BcastOff& bcast,
                      const CSRMatrix& csr,
                      NDArray ufeat,
                      NDArray efeat,
                      NDArray out,
                      std::vector out_aux) {
  if (reduce == "sum") {
    // Sum reduction: out_aux is not touched on this path.
    SWITCH_OP(op, Op, {
      cpu::SpMMSumCsr(bcast, csr, ufeat, efeat, out);
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_OP(op, Op, {
      // NOTE(review): the two calls below are textually identical in this
      // copy; presumably they differ only in a stripped comparator template
      // argument (max vs. min) -- TODO confirm.
      if (reduce == "max")
        cpu::SpMMCmpCsr>(
            bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
      else
        cpu::SpMMCmpCsr>(
            bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else {
    // Any other reducer name is reported through the fatal log stream.
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

// NOTE(review): the four declarations below are textually identical in this
// copy. They look like explicit instantiations of SpMMCsr whose template
// argument lists were stripped -- TODO restore from the repository.
template void SpMMCsr(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector out_aux);
template void SpMMCsr(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector out_aux);
template void SpMMCsr(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector out_aux);
template void SpMMCsr(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector out_aux);

/*! \brief Generalized SpMM on Coo format.
*/
// Chooses a CPU kernel from the reducer name and forwards the arguments:
//  - "sum":           cpu::SpMMSumCoo(bcast, coo, ufeat, efeat, out)
//  - "max" or "min":  cpu::SpMMCmpCoo(bcast, coo, ufeat, efeat, out,
//                                     out_aux[0], out_aux[1])
//  - anything else:   LOG(FATAL) with "Unsupported SpMM reducer: " + reduce
// Every kernel call is wrapped in SWITCH_OP(op, Op, ...), a macro defined
// outside this view (presumably it binds `Op` to the operator selected by
// `op` -- TODO confirm).
//
// Parameters:
//   op       Operator name handed to SWITCH_OP.
//   reduce   Reducer name; "sum", "max" and "min" are handled.
//   bcast    Forwarded unchanged to the kernel.
//   coo      Sparse matrix in COO format, forwarded to the kernel.
//   ufeat    Forwarded to the kernel (presumably source-node features --
//            TODO confirm).
//   efeat    Forwarded to the kernel (presumably edge features -- TODO
//            confirm).
//   out      Forwarded to the kernel (presumably the result buffer -- TODO
//            confirm).
//   out_aux  Elements [0] and [1] are forwarded to the kernel for "max" and
//            "min"; they are indexed without a size check, so the caller must
//            supply at least two entries for those reducers. Not read for
//            "sum".
//
// NOTE(review): `template` is followed directly by `void` although the
// function has a body, `std::vector` has no element type, and a stray `>`
// precedes `(` in the SpMMCmpCoo calls. The template parameter list and the
// kernels' template argument lists appear to be missing from this copy --
// restore them from the repository.
template void SpMMCoo(const std::string& op, const std::string& reduce,
                      const BcastOff& bcast,
                      const COOMatrix& coo,
                      NDArray ufeat,
                      NDArray efeat,
                      NDArray out,
                      std::vector out_aux) {
  if (reduce == "sum") {
    // Sum reduction: out_aux is not touched on this path.
    SWITCH_OP(op, Op, {
      cpu::SpMMSumCoo(bcast, coo, ufeat, efeat, out);
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_OP(op, Op, {
      // NOTE(review): the two calls below are textually identical in this
      // copy; presumably they differ only in a stripped comparator template
      // argument (max vs. min) -- TODO confirm.
      if (reduce == "max")
        cpu::SpMMCmpCoo>(
            bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
      else
        cpu::SpMMCmpCoo>(
            bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else {
    // Any other reducer name is reported through the fatal log stream.
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

// NOTE(review): the four declarations below are textually identical in this
// copy. They look like explicit instantiations of SpMMCoo whose template
// argument lists were stripped -- TODO restore from the repository.
template void SpMMCoo(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector out_aux);
template void SpMMCoo(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector out_aux);
template void SpMMCoo(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector out_aux);
template void SpMMCoo(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector out_aux);

// NOTE(review): in this copy the brace that closes `namespace dgl` is part of
// the line comment below (the file was collapsed onto one line), so only
// `namespace aten` is actually closed here -- confirm against the repository
// and put the second brace back on its own line.
}  // namespace aten } // namespace dgl