/*!
 *  Copyright (c) 2020 by Contributors
 * \file kernel/cpu/spmm.cc
 * \brief SPMM C APIs and definitions.
 */
#include "./spmm.h"
#include <dgl/array.h>

#include <algorithm>
#include <string>
#include <vector>

namespace dgl {
namespace aten {

/*!
 * \brief Generalized SpMM on Csr format.
 *
 * Dispatches on the feature bit width (\p bits) and the message operator
 * (\p op), then runs the CPU kernel matching \p reduce.
 *
 * \param op      Message operator name, dispatched through SWITCH_OP.
 * \param reduce  Reducer name: "sum", "max" or "min". Anything else is fatal.
 * \param bcast   Broadcast description; \c bcast.out_len is the per-row width
 *                of \p out.
 * \param csr     Sparse matrix; \c csr.num_rows output rows are produced.
 * \param ufeat   Source-node features (may be a null array).
 * \param efeat   Edge features (may be a null array).
 * \param out     Output buffer, written in place.
 * \param out_aux For "max"/"min" only: out_aux[0] and out_aux[1] are handed to
 *                the kernel as the two arg-index outputs. Unused for "sum".
 */
template <int XPU, typename IdType, int bits>
void SpMMCsr(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
      });
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        // Seed every output element with the reducer's identity value before
        // the comparison kernel runs.
        DType* out_off = out.Ptr<DType>();
        if (reduce == "max") {
          std::fill(out_off, out_off + csr.num_rows * dim,
                    cpu::op::Max<DType>::zero);
          cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
              bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
        } else {
          std::fill(out_off, out_off + csr.num_rows * dim,
                    cpu::op::Min<DType>::zero);
          cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
              bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
        }
      });
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

/*!
 * \brief Generalized SpMM on Csr format for heterogeneous graphs.
 *
 * Runs one SpMM per relation type. Relation \c etype reads source features
 * \c vec_ufeat[ufeat_node_tids[etype]] and edge features \c vec_efeat[etype].
 * It writes into the output of its destination node type,
 * \c vec_out[out_node_tids[etype]]. Relations that share a destination node
 * type therefore reduce into the same output buffer.
 *
 * \param op              Message operator name, dispatched through SWITCH_OP.
 * \param reduce          Reducer name: "sum", "max" or "min". Anything else is
 *                        fatal.
 * \param bcast           Broadcast description; \c bcast.out_len is the
 *                        per-row output width.
 * \param vec_csr         One CSR matrix per relation type.
 * \param vec_ufeat       Per-node-type source features; empty means "none".
 * \param vec_efeat       Per-relation edge features; empty means "none".
 * \param vec_out         Per-node-type outputs. Taken by value, but written
 *                        through the contained NDArray handles (presumably
 *                        they share storage with the caller's arrays).
 * \param out_aux         For "max"/"min": out_aux[0] and out_aux[1] are passed
 *                        to the kernel for every relation.
 *                        NOTE(review): these buffers are shared by all
 *                        relations, so relations may overwrite each other's
 *                        arg indices — confirm this is intended.
 * \param ufeat_node_tids Source node type of each relation.
 * \param out_node_tids   Destination node type of each relation.
 */
template <int XPU, typename IdType, int bits>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
                   const BcastOff& bcast,
                   const std::vector<CSRMatrix>& vec_csr,
                   const std::vector<NDArray>& vec_ufeat,
                   const std::vector<NDArray>& vec_efeat,
                   std::vector<NDArray> vec_out,
                   const std::vector<NDArray>& out_aux,
                   const std::vector<dgl_type_t>& ufeat_node_tids,
                   const std::vector<dgl_type_t>& out_node_tids) {
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        /* Call SpMM for each relation type */
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          const dgl_type_t src_id = ufeat_node_tids[etype];
          const dgl_type_t dst_id = out_node_tids[etype];
          const CSRMatrix& csr = vec_csr[etype];
          NDArray ufeat = vec_ufeat.empty() ? NullArray() : vec_ufeat[src_id];
          NDArray efeat = vec_efeat.empty() ? NullArray() : vec_efeat[etype];
          NDArray out = vec_out[dst_id];
          cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
        }
      });
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        // Seed each destination node type's output with the reducer identity
        // exactly once, before any relation runs. Seeding inside the
        // per-relation loop would wipe out values already reduced into the
        // same buffer by an earlier relation that shares the destination
        // type. Assumes SpMMCmpCsr folds into the existing contents of `out`,
        // as the pre-fill in SpMMCsr suggests — confirm.
        std::vector<char> seeded(vec_out.size(), 0);
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          const dgl_type_t dst_id = out_node_tids[etype];
          if (!seeded[dst_id]) {
            seeded[dst_id] = 1;
            DType* out_off = vec_out[dst_id].Ptr<DType>();
            const int64_t len = vec_csr[etype].num_rows * dim;
            if (reduce == "max") {
              std::fill(out_off, out_off + len, cpu::op::Max<DType>::zero);
            } else {
              std::fill(out_off, out_off + len, cpu::op::Min<DType>::zero);
            }
          }
        }
        /* Call SpMM for each relation type */
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          const dgl_type_t src_id = ufeat_node_tids[etype];
          const dgl_type_t dst_id = out_node_tids[etype];
          const CSRMatrix& csr = vec_csr[etype];
          NDArray ufeat = vec_ufeat.empty() ? NullArray() : vec_ufeat[src_id];
          NDArray efeat = vec_efeat.empty() ? NullArray() : vec_efeat[etype];
          NDArray out = vec_out[dst_id];
          if (reduce == "max") {
            cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
                bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
          } else {
            cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
                bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
          }
        }
      });
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

template void SpMMCsr<kDLCPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

template void SpMMCsrHetero<kDLCPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);

/*!
 * \brief Generalized SpMM on Coo format.
 *
 * Same contract as SpMMCsr, but takes a COO matrix. Unlike the CSR path, the
 * output is not pre-filled with the reducer identity here for "max"/"min".
 * Presumably cpu::SpMMCmpCoo initializes \p out itself — confirm.
 *
 * \param op      Message operator name, dispatched through SWITCH_OP.
 * \param reduce  Reducer name: "sum", "max" or "min". Anything else is fatal.
 * \param bcast   Broadcast description.
 * \param coo     Sparse matrix in COO format.
 * \param ufeat   Source-node features (may be a null array).
 * \param efeat   Edge features (may be a null array).
 * \param out     Output buffer, written in place.
 * \param out_aux For "max"/"min" only: out_aux[0] and out_aux[1] are handed to
 *                the kernel as the two arg-index outputs. Unused for "sum".
 */
template <int XPU, typename IdType, int bits>
void SpMMCoo(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const COOMatrix& coo,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
  if (reduce == "sum") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
      });
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        if (reduce == "max") {
          cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>(
              bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
        } else {
          cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>(
              bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
        }
      });
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

template void SpMMCoo<kDLCPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

}  // namespace aten
}  // namespace dgl