/**
 *  Copyright (c) 2020 by Contributors
 * @file kernel/cpu/spmm.cc
 * @brief SPMM C APIs and definitions.
 */
#include "./spmm.h"

#include <dgl/array.h>

#include <algorithm>
#include <string>
#include <vector>

namespace dgl {
namespace aten {

/**
 * @brief Generalized SpMM on Csr format.
 *
 * Resolves the binary operator @p op with SWITCH_OP and dispatches to the
 * matching cpu:: kernel for the requested reducer.
 *
 * @param op      Name of the binary message operator (resolved by SWITCH_OP).
 * @param reduce  Reducer name: "sum", "max" or "min"; anything else is fatal.
 * @param bcast   Broadcast offsets; bcast.out_len is the output length per row.
 * @param csr     Sparse matrix in CSR format.
 * @param ufeat   Feature tensor passed to the kernel as the "u" operand.
 * @param efeat   Feature tensor passed to the kernel as the "e" operand.
 * @param out     Output tensor. For "max"/"min" its first
 *                csr.num_rows * bcast.out_len entries are seeded with the
 *                reducer's identity value before the kernel runs.
 * @param out_aux For "max"/"min", out_aux[0] and out_aux[1] are forwarded to
 *                cpu::SpMMCmpCsr (presumably arg-index buffers; TODO confirm).
 */
template <int XPU, typename IdType, typename DType>
void SpMMCsr(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux) {
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
    SWITCH_OP(op, Op, {
      cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_OP(op, Op, {
      DType* out_off = out.Ptr<DType>();
      // Seed every output slot with the reducer's identity value before the
      // compare kernel runs.
      if (reduce == "max") {
        std::fill(
            out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
        cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
            bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
      } else {
        std::fill(
            out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
        cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
            bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
      }
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

/**
 * @brief Generalized SpMM on Csr format for heterogeneous graphs.
 *
 * Runs one SpMM per relation type @c etype. That relation reads
 * vec_ufeat[ufeat_node_tids[etype]] and vec_efeat[etype], and writes into
 * (*vec_out)[out_node_tids[etype]]. Several relation types may therefore
 * accumulate into the same output tensor.
 *
 * @param op              Name of the binary operator (resolved by SWITCH_OP).
 * @param reduce          Reducer name: "sum", "max" or "min"; else fatal.
 * @param bcast           Broadcast offsets; bcast.out_len is the per-row
 *                        output length.
 * @param vec_csr         One CSR matrix per relation type.
 * @param vec_ufeat       "u" features indexed by ufeat_node_tids[etype]; when
 *                        empty, NullArray() is passed to the kernel.
 * @param vec_efeat       "e" features indexed by etype; when empty,
 *                        NullArray() is passed to the kernel.
 * @param vec_out         Output tensors indexed by out_node_tids[etype].
 * @param out_aux         For "max"/"min": (*out_aux)[0..3][dst] are forwarded
 *                        to cpu::SpMMCmpCsrHetero. (*out_aux)[2] (when
 *                        Op::use_lhs) and (*out_aux)[3] (when Op::use_rhs)
 *                        are pre-filled with -1. Judging by their local names
 *                        they presumably record the winning node type / edge
 *                        type; TODO confirm.
 * @param ufeat_node_tids Per relation type, the index into vec_ufeat.
 * @param out_node_tids   Per relation type, the index into *vec_out.
 */
template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat, std::vector<NDArray>* vec_out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids) {
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
    SWITCH_OP(op, Op, {
      /* Call SpMM for each relation type */
      for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
        const dgl_type_t src_id = ufeat_node_tids[etype];
        const dgl_type_t dst_id = out_node_tids[etype];
        CSRMatrix csr = vec_csr[etype];
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        NDArray out = (*vec_out)[dst_id];
        cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
      }
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_OP(op, Op, {
      // Initialise each destination's buffers exactly once, before any
      // kernel runs. Several relation types may share a destination, and
      // refilling it for each of them would be redundant.
      std::vector<bool> updated(vec_out->size(), false);
      for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
        const dgl_type_t dst_id = out_node_tids[etype];
        if (!updated[dst_id]) {
          updated[dst_id] = true;
          const int64_t num_out = vec_csr[etype].num_rows * dim;
          DType* out_off = (*vec_out)[dst_id].Ptr<DType>();
          if (reduce == "max")
            std::fill(out_off, out_off + num_out, cpu::op::Max<DType>::zero);
          else
            std::fill(out_off, out_off + num_out, cpu::op::Min<DType>::zero);
          if (Op::use_lhs) {
            IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
            std::fill(argu_ntype, argu_ntype + num_out, -1);
          }
          if (Op::use_rhs) {
            IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
            std::fill(arge_etype, arge_etype + num_out, -1);
          }
        }
      }
      /* Call SpMM for each relation type */
      for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
        const dgl_type_t src_id = ufeat_node_tids[etype];
        const dgl_type_t dst_id = out_node_tids[etype];
        CSRMatrix csr = vec_csr[etype];
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        NDArray out = (*vec_out)[dst_id];
        if (reduce == "max") {
          cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
              bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],
              (*out_aux)[1][dst_id], (*out_aux)[2][dst_id],
              (*out_aux)[3][dst_id], src_id, etype);
        } else {
          cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
              bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],
              (*out_aux)[1][dst_id], (*out_aux)[2][dst_id],
              (*out_aux)[3][dst_id], src_id, etype);
        }
      }
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

template void SpMMCsr<kDGLCPU, int32_t, BFloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, BFloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

template void SpMMCsrHetero<kDGLCPU, int32_t, BFloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, BFloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);

/**
 * @brief Edge_softmax_csr forward op on Csr format.
 *
 * Resolves @p op with SWITCH_OP and forwards all tensors unchanged to
 * cpu::Edge_softmax_csr_forward.
 */
template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_forward(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out) {
  SWITCH_OP(op, Op, {
    cpu::Edge_softmax_csr_forward<IdType, DType, Op>(
        bcast, csr, ufeat, efeat, out);
  });
}

/**
 * @brief Edge_softmax_csr backward op on Csr format.
 *
 * Resolves @p op with SWITCH_OP and forwards all tensors unchanged to
 * cpu::Edge_softmax_csr_backward.
 */
template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_backward(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out) {
  SWITCH_OP(op, Op, {
    cpu::Edge_softmax_csr_backward<IdType, DType, Op>(
        bcast, csr, out, sds, back_out);
  });
}

template void Edge_softmax_csr_forward<kDGLCPU, int32_t, BFloat16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, BFloat16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);

template void Edge_softmax_csr_backward<kDGLCPU, int32_t, BFloat16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, BFloat16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);

/**
 * @brief Generalized SpMM on Coo format.
 *
 * Same dispatch as SpMMCsr, but on a COO matrix. Unlike SpMMCsr, this
 * wrapper does not pre-fill @p out for "max"/"min"; any initialisation is
 * left to cpu::SpMMCmpCoo.
 *
 * @param reduce  Reducer name: "sum", "max" or "min"; anything else is fatal.
 * @param out_aux For "max"/"min", out_aux[0] and out_aux[1] are forwarded to
 *                cpu::SpMMCmpCoo.
 */
template <int XPU, typename IdType, typename DType>
void SpMMCoo(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux) {
  if (reduce == "sum") {
    SWITCH_OP(op, Op, {
      cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_OP(op, Op, {
      if (reduce == "max")
        cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>(
            bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
      else
        cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>(
            bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else {
    LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
  }
}

template void SpMMCoo<kDGLCPU, int32_t, BFloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, BFloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

}  // namespace aten
}  // namespace dgl