/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/spmm.cu
 * @brief SPMM C APIs and definitions.
 */
#include <dgl/array.h>
#include "./spmm.cuh"
#include "./ge_spmm.cuh"
#include "./functor.cuh"
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {

using namespace cuda;

namespace aten {

/**
 * @brief CUDA implementation of g-SpMM on Csr format.
 * @note For reduce == "sum", cuSPARSE (CusparseCsrmm2) is used when op is
 *       "copy_lhs", or when op is "mul" and efeat holds exactly one value
 *       per stored nonzero of @p csr -- in both cases only if
 *       cusparse_available<DType, IdType>() allows it. Every other
 *       combination runs DGL's own kernel (cuda::SpMMCsr).
 * @param op Binary message operator name, dispatched through SWITCH_OP.
 * @param reduce Reduction name: "sum", "max" or "min"; anything else is a
 *        fatal error.
 * @param bcast Broadcast offsets forwarded to DGL's kernel.
 * @param csr Sparse matrix in CSR format.
 * @param ufeat Left-hand input features; the product of its non-leading
 *        dimensions is the dense width handed to cuSPARSE.
 * @param efeat Values attached to the nonzeros of @p csr (gathered through
 *        csr.data when that array is present).
 * @param out Output array written by the selected kernel.
 * @param out_aux For "max"/"min", must hold at least two arrays; out_aux[0]
 *        and out_aux[1] are passed to the kernel (presumably arg-index
 *        outputs -- confirm in spmm.cuh). Unused for "sum".
 */
template <int XPU, typename IdType, typename DType>
void SpMMCsr(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux) {
  if (reduce == "sum") {
    // nnz can only exceed num_rows * num_cols when the CSR stores duplicate
    // entries; cusparse_available() takes this flag into account.
    const bool more_nnz =
        (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
    const bool is_copy_lhs = (op == "copy_lhs");
    // "mul" can go to cuSPARSE only if efeat is one scalar per nonzero.
    const bool is_scalar_mul =
        (op == "mul" && efeat.NumElements() == csr.indices->shape[0]);

    if ((is_copy_lhs || is_scalar_mul) &&
        cusparse_available<DType, IdType>(more_nnz)) {
      // cuSPARSE path. x_length = number of values in one row of ufeat,
      // i.e. the product of all of its dimensions except the first.
      int64_t x_length = 1;
      for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];

      // copy_lhs carries no edge values, so cuSPARSE receives nullptr.
      DType* efeat_data = nullptr;
      if (is_scalar_mul) {
        if (!IsNullArray(csr.data)) {
          // Gather efeat through csr.data so its values follow the CSR
          // storage order of the nonzeros.
          efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
        }
        efeat_data = static_cast<DType*>(efeat->data);
      }
      CusparseCsrmm2<DType, IdType>(
          ufeat->ctx, csr, static_cast<DType*>(ufeat->data), efeat_data,
          static_cast<DType*>(out->data), x_length);
    } else {
      // General DGL kernel; sum needs no auxiliary outputs.
      SWITCH_OP(op, Op, {
        cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
            bcast, csr, ufeat, efeat, out, NullArray(), NullArray());
      });
    }
  } else if (reduce == "max") {
    CHECK_GE(out_aux.size(), 2u)
        << "SpMMCsr with reduce=max requires two auxiliary output arrays";
    SWITCH_OP(op, Op, {
      cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
          bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else if (reduce == "min") {
    CHECK_GE(out_aux.size(), 2u)
        << "SpMMCsr with reduce=min requires two auxiliary output arrays";
    SWITCH_OP(op, Op, {
      cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
          bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else {
    LOG(FATAL) << "Not implemented";
  }
}

/**
 * @brief CUDA implementation of g-SpMM on Coo format.
*/
template <int XPU, typename IdType, typename DType>
void SpMMCoo(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux) {
  // Always dispatches to DGL's COO kernel (there is no cuSPARSE path here).
  // The reducers get `true` as their third template argument for COO,
  // presumably selecting atomic accumulation because several nonzeros may
  // update the same output row concurrently -- confirm in functor.cuh.
  // "max"/"min" need at least two auxiliary outputs (out_aux[0..1]);
  // "sum" needs none. Any other reduce name is a fatal error.
  if (reduce == "sum") {
    SWITCH_OP(op, Op, {
      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> >(
          bcast, coo, ufeat, efeat, out, NullArray(), NullArray());
    });
  } else if (reduce == "max") {
    CHECK_GE(out_aux.size(), 2u)
        << "SpMMCoo with reduce=max requires two auxiliary output arrays";
    SWITCH_OP(op, Op, {
      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Max<IdType, DType, true> >(
          bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else if (reduce == "min") {
    CHECK_GE(out_aux.size(), 2u)
        << "SpMMCoo with reduce=min requires two auxiliary output arrays";
    SWITCH_OP(op, Op, {
      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Min<IdType, DType, true> >(
          bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else {
    LOG(FATAL) << "Not implemented";
  }
}

// Explicit instantiations for the CUDA device: 32- and 64-bit index types
// combined with half, bfloat16 (only when BF16_ENABLED), float and double.
template void SpMMCsr<kDGLCUDA, int32_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
#if BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
#endif  // BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

template void SpMMCoo<kDGLCUDA, int32_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
#if BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
#endif  // BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

}  // namespace aten
}  // namespace dgl