/*!
 *  Copyright (c) 2020 by Contributors
 * \file kernel/cpu/spmm.cc
 * \brief SPMM C APIs and definitions.
 */
#include "./spmm.h"
#include <dgl/array.h>

#include <algorithm>
#include <string>
#include <vector>

// NOTE(review): the template parameter lists / template arguments in this file
// were reconstructed (they were missing from the reviewed text). The names
// IdType, DType, Op and bits come from their uses in the bodies; the leading
// device parameter (XPU / kDLCPU) and the <dgl/array.h> include should be
// confirmed against the declarations this file implements.

namespace dgl {
namespace aten {

/*!
 * \brief Generalized SpMM on Csr format.
 *
 * Selects the CPU kernel from \a reduce; SWITCH_BITS binds DType from
 * \a bits and SWITCH_OP binds the functor type Op from \a op.
 *
 * \param op Name of the binary operator, resolved by SWITCH_OP.
 * \param reduce Reducer name: "sum", "max" or "min". Any other value is
 *        reported through LOG(FATAL).
 * \param bcast Broadcast information; bcast.out_len is read here as the
 *        per-row output length, the object is otherwise forwarded as is.
 * \param csr The sparse matrix in CSR format.
 * \param ufeat First feature operand, forwarded to the kernel.
 * \param efeat Second feature operand, forwarded to the kernel.
 * \param out Output array. For "max"/"min" its first
 *        csr.num_rows * bcast.out_len elements are pre-filled with the
 *        reducer's `zero` value before the kernel runs.
 * \param out_aux Auxiliary outputs. Only read for "max"/"min", where
 *        out_aux[0] and out_aux[1] are forwarded to cpu::SpMMCmpCsr
 *        (presumably the arg-index arrays for the two operands - confirm
 *        against the kernel). Must hold at least two entries in that case.
 */
template <int XPU, typename IdType, int bits>
void SpMMCsr(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
      });
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        DType* out_off = out.Ptr<DType>();
        if (reduce == "max") {
          // Seed the output with the reducer's initial value before comparing.
          std::fill(out_off, out_off + csr.num_rows * dim,
                    cpu::op::Max<DType>::zero);
          cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
              bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
        } else {
          std::fill(out_off, out_off + csr.num_rows * dim,
                    cpu::op::Min<DType>::zero);
          cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
              bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
        }
      });
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

/*!
 * \brief Generalized SpMM on Csr format, applied once per relation (edge) type.
 *
 * For every edge type `etype` in [0, ufeat_node_tids.size()) the kernel is
 * run on vec_csr[etype], reading the source features at index
 * ufeat_node_tids[etype] and writing to the output at index
 * out_node_tids[etype].
 *
 * \param op Name of the binary operator, resolved by SWITCH_OP.
 * \param reduce Reducer name: "sum", "max" or "min". Any other value is
 *        reported through LOG(FATAL).
 * \param bcast Broadcast information; bcast.out_len is read here as the
 *        per-row output length.
 * \param vec_csr One CSR matrix per edge type.
 * \param vec_ufeat Feature arrays indexed by ufeat_node_tids[etype]; when the
 *        vector is empty a NullArray() is passed to the kernel instead.
 * \param vec_efeat Feature arrays indexed by edge type; when the vector is
 *        empty a NullArray() is passed to the kernel instead.
 * \param vec_out Output arrays indexed by out_node_tids[etype]. For
 *        "max"/"min" each referenced output is pre-filled with the reducer's
 *        `zero` value.
 * \param out_aux Auxiliary outputs, only used for "max"/"min". Entries
 *        [0]..[3] are each indexed by the output id and forwarded to
 *        cpu::SpMMCmpCsrHetero; entries [2] and [3] are additionally
 *        initialised to -1 here (when Op::use_lhs / Op::use_rhs respectively).
 * \param ufeat_node_tids For each edge type, the index into vec_ufeat.
 * \param out_node_tids For each edge type, the index into vec_out / out_aux[i].
 */
template <int XPU, typename IdType, int bits>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
                   const BcastOff& bcast,
                   const std::vector<CSRMatrix>& vec_csr,
                   const std::vector<NDArray>& vec_ufeat,
                   const std::vector<NDArray>& vec_efeat,
                   std::vector<NDArray>* vec_out,
                   std::vector<std::vector<NDArray>>* out_aux,
                   const std::vector<dgl_type_t>& ufeat_node_tids,
                   const std::vector<dgl_type_t>& out_node_tids) {
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        /* Call SpMM for each relation type */
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          const dgl_type_t src_id = ufeat_node_tids[etype];
          const dgl_type_t dst_id = out_node_tids[etype];
          CSRMatrix csr = vec_csr[etype];
          NDArray ufeat = (vec_ufeat.size() == 0) ?
            NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ?
            NullArray() : vec_efeat[etype];
          NDArray out = (*vec_out)[dst_id];
          cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
        }
      });
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        // Tracks which output ids already had their auxiliary arrays reset,
        // so that each one is initialised only once.
        std::vector<bool> updated((*vec_out).size(), false);
        // TODO(Israt): use vector updated to fill(out...) too
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          const dgl_type_t dst_id = out_node_tids[etype];
          const auto out_len = vec_csr[etype].num_rows * dim;
          DType* out_off = (*vec_out)[dst_id].Ptr<DType>();
          if (reduce == "max")
            std::fill(out_off, out_off + out_len, cpu::op::Max<DType>::zero);
          else
            std::fill(out_off, out_off + out_len, cpu::op::Min<DType>::zero);
          if (!updated[dst_id]) {
            updated[dst_id] = true;
            if (Op::use_lhs) {
              IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
              std::fill(argu_ntype, argu_ntype + out_len, -1);
            }
            if (Op::use_rhs) {
              IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
              std::fill(arge_etype, arge_etype + out_len, -1);
            }
          }
        }
        /* Call SpMM for each relation type */
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          const dgl_type_t src_id = ufeat_node_tids[etype];
          const dgl_type_t dst_id = out_node_tids[etype];
          CSRMatrix csr = vec_csr[etype];
          NDArray ufeat = (vec_ufeat.size() == 0) ?
            NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ?
            NullArray() : vec_efeat[etype];
          NDArray out = (*vec_out)[dst_id];
          if (reduce == "max") {
            cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
                bcast, csr, ufeat, efeat, out,
                (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
                (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
                src_id, etype);
          } else {
            cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
                bcast, csr, ufeat, efeat, out,
                (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
                (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
                src_id, etype);
          }
        }
      });
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

/* Explicit instantiations of SpMMCsr: {int32_t, int64_t} ids x {16, 32, 64} bits. */
template void SpMMCsr<kDLCPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

/* Explicit instantiations of SpMMCsrHetero: {int32_t, int64_t} ids x {16, 32, 64} bits. */
template void SpMMCsrHetero<kDLCPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);

/*!
 * \brief Generalized SpMM on Coo format.
 *
 * Selects the CPU kernel from \a reduce; SWITCH_BITS binds DType from
 * \a bits and SWITCH_OP binds the functor type Op from \a op.
 *
 * \param op Name of the binary operator, resolved by SWITCH_OP.
 * \param reduce Reducer name: "sum", "max" or "min". Any other value is
 *        reported through LOG(FATAL).
 * \param bcast Broadcast information, forwarded to the kernel.
 * \param coo The sparse matrix in COO format.
 * \param ufeat First feature operand, forwarded to the kernel.
 * \param efeat Second feature operand, forwarded to the kernel.
 * \param out Output array, forwarded to the kernel. Unlike the CSR variant,
 *        this function does not pre-fill it.
 * \param out_aux Auxiliary outputs. Only read for "max"/"min", where
 *        out_aux[0] and out_aux[1] are forwarded to cpu::SpMMCmpCoo. Must
 *        hold at least two entries in that case.
 */
template <int XPU, typename IdType, int bits>
void SpMMCoo(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const COOMatrix& coo,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
  if (reduce == "sum") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
      });
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
        if (reduce == "max")
          cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>(
              bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
        else
          cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>(
              bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
      });
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

/* Explicit instantiations of SpMMCoo: {int32_t, int64_t} ids x {16, 32, 64} bits. */
template void SpMMCoo<kDLCPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

}  // namespace aten
}  // namespace dgl