/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/spmm.cu
 * \brief SPMM C APIs and definitions.
 */
#include <dgl/array.h>
#include "./spmm.cuh"
#include "./ge_spmm.cuh"
#include "./functor.cuh"
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {

using namespace cuda;

namespace aten {

/*!
 * \brief Determine whether cusparse SpMM function is applicable.
 */
template <int bits, typename IdType>
inline bool cusparse_available(bool more_nnz_than_matrix_size) {
#if CUDART_VERSION < 11000
  if (std::is_same<IdType, int>::value && bits > 16)
    return true;
  return false;
#else
  if (bits == 16)
    return false;  // cusparse's SpMM on fp16 is slow, temporarily disabled.
  // If the CSR matrix has more NNZ than matrix size, we should not use cuSPARSE 11.1.
  return !more_nnz_than_matrix_size;
#endif
}

/*!
 * \brief CUDA implementation of g-SpMM on Csr format.
 * \note use cusparse if the reduce operator is `sum` and there is
 *       no broadcast, use dgl's kernel in other cases.
 */
template <int XPU, typename IdType, int bits>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
                   const BcastOff& bcast,
                   const std::vector<CSRMatrix>& vec_csr,
                   const std::vector<NDArray>& vec_ufeat,
                   const std::vector<NDArray>& vec_efeat,
                   std::vector<NDArray>* vec_out,
                   std::vector<std::vector<NDArray>>* out_aux,
                   const std::vector<dgl_type_t>& ufeat_ntids,  // ufeat node type id
                   const std::vector<dgl_type_t>& out_ntids) {  // output node type id
  bool is_scalar_efeat = vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
  bool use_efeat = op != "copy_lhs";
  auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
  SWITCH_BITS(bits, DType, {
    std::vector<DType*> trans_out((*vec_out).size(), NULL);

    bool use_legacy_cusparsemm =
        // (CUDART_VERSION < 11000) && (reduce == "sum") &&
        (CUDART_VERSION_LT_11000) && (reduce == "sum") &&
        // legacy cuSPARSE does not care about NNZ, hence the argument "false".
        ((op == "copy_lhs" && cusparse_available<bits, IdType>(false)) ||
         (op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>(false)));
    // Create temporary output buffer to store non-transposed output
    if (use_legacy_cusparsemm) {
      for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
        const int m = (*vec_out)[ntype]->shape[0];
        const int n = (*vec_out)[ntype]->shape[1];
        if (m == 0) continue;
        DType *out = static_cast<DType*>(device->AllocWorkspace(
            vec_csr[0].indptr->ctx, m * n * sizeof(DType)));
        CUDA_CALL(hipMemset(out, 0, m * n * sizeof(DType)));
        trans_out[ntype] = out;
      }
    }

    // Check shape of ufeat for all relation types and compute feature size
    int64_t x_length = 1;
    for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) {
      NDArray ufeat = vec_ufeat[ufeat_ntids[etype]];
      NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]];
      CHECK_EQ(ufeat->ndim, next_ufeat->ndim) << "Input features have different shapes";
      for (int i = 1; i < ufeat->ndim; ++i) {
        if (ufeat->shape[i] != next_ufeat->shape[i]) {
          if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1)
            LOG(FATAL) <<
              "Homogenized message passing on heterogeneous graphs does not support " <<
              "automatic broadcasting. Please manually broadcast it before calling " <<
              "message passing functions.";
          else
            LOG(FATAL) << "Input features have different shapes.";
          return;
        }
        if (etype == 0)
          x_length *= ufeat->shape[i];
      }
    }
    // TODO(Israt): Can python do the following initializations while creating the tensors?
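    // For max/min reduction, pre-fill each output with the reduction identity and the
    // arg-index arrays with -1 so that rows receiving no message stay well defined.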
    if (reduce == "max" || reduce == "min") {
      const int64_t dim = bcast.out_len;
      std::vector<bool> updated((*vec_out).size(), false);
      for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
        DType *out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>();
        if (reduce == "max")
          _Fill(out_off, vec_csr[etype].num_rows * dim,
                cuda::reduce::Max<IdType, DType>::zero());
        else  // min
          _Fill(out_off, vec_csr[etype].num_rows * dim,
                cuda::reduce::Min<IdType, DType>::zero());
        const dgl_type_t dst_id = out_ntids[etype];
        if (!updated[dst_id]) {
          updated[dst_id] = true;
          if (op == "copy_lhs") {
            IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
            _Fill(argu_ntype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
          }
          if (op == "copy_rhs") {
            IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
            _Fill(arge_etype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
          }
        }
      }
    }

    hipStream_t stream = runtime::getCurrentCUDAStream();
    for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
      const dgl_type_t src_id = ufeat_ntids[etype];
      const dgl_type_t dst_id = out_ntids[etype];
      CSRMatrix csr = vec_csr[etype];
      if (reduce == "sum") {
        bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
        /* Call SpMM for each relation type */
        if (op == "copy_lhs" && cusparse_available<bits, IdType>(more_nnz)) {  // cusparse
          /* If CUDA is less than 11.0, put the output in trans_out for later transposition */
          // DType *out = (CUDART_VERSION < 11000) ? trans_out[dst_id] :
          DType *out = (CUDART_VERSION_LT_11000) ? trans_out[dst_id] :
              static_cast<DType*>((*vec_out)[dst_id]->data);
          CusparseCsrmm2Hetero<DType, IdType>(
              csr.indptr->ctx, csr,
              static_cast<DType*>(vec_ufeat[src_id]->data),
              nullptr,
              out,
              x_length, stream);
        } else if (op == "mul" && is_scalar_efeat &&
                   cusparse_available<bits, IdType>(more_nnz)) {  // cusparse
          NDArray efeat = vec_efeat[etype];
          if (!IsNullArray(csr.data))
            efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
          CusparseCsrmm2Hetero<DType, IdType>(
              csr.indptr->ctx, csr,
              static_cast<DType*>(vec_ufeat[src_id]->data),
              static_cast<DType*>(efeat->data),
              // TODO(Israt): Change (*vec_out) to trans_out to support CUDA version < 11
              static_cast<DType*>((*vec_out)[dst_id]->data),
              x_length, stream);
        } else {  // general kernel
          NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
          SWITCH_OP(op, Op, {
            cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
                bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(), NullArray());
          });
        }
      } else if (reduce == "max") {
        SWITCH_OP(op, Op, {
          NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
          cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
              bcast, csr, ufeat, efeat, (*vec_out)[dst_id],
              (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
              (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
              src_id, etype);
        });
      } else if (reduce == "min") {
        SWITCH_OP(op, Op, {
          NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
          cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
              bcast, csr, ufeat, efeat, (*vec_out)[dst_id],
              (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
              (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
              src_id, etype);
        });
      } else {
        LOG(FATAL) << "Not implemented";
      }
    }

    if (use_legacy_cusparsemm) {
      // transpose output
      for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
        const int m = (*vec_out)[ntype]->shape[0];
        const int n = (*vec_out)[ntype]->shape[1];
        if (m == 0) continue;
        DType *C_data = static_cast<DType*>((*vec_out)[ntype]->data);
        _Transpose(trans_out[ntype], C_data, n, m);
        device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
      }
    }
  });
}

template void SpMMCsrHetero<kDLGPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);

}  // namespace aten
}  // namespace dgl