/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/spmm.cu
 * \brief SPMM C APIs and definitions.
 */
#include <dgl/array.h>
#include "./spmm.cuh"
#include "./functor.cuh"
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {

using namespace cuda;

namespace aten {
namespace {

/*! \brief Fill the vector starting at ptr, of the given length, with val. */
template <typename DType>
void _Fill(DType* ptr, size_t length, DType val) {
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  int nt = FindNumThreads(length);
  int nb = (length + nt - 1) / nt;  // on the x-axis, no need to worry about the upper bound.
  CUDA_KERNEL_CALL(cuda::_FillKernel, nb, nt, 0, thr_entry->stream,
                   ptr, length, val);
}

}  // namespace

namespace cusparse {

#if CUDART_VERSION < 11000
template <typename DType>
cusparseStatus_t Xcsrmm2(cusparseHandle_t handle, cusparseOperation_t transA,
    cusparseOperation_t transB, int m, int n, int k, int nnz,
    const DType* alpha, const cusparseMatDescr_t descrA,
    const DType* csrValA, const int* csrRowPtrA, const int* csrColIndA,
    const DType* B, int ldb, const DType* beta, DType* C, int ldc) {
  LOG(INFO) << "Not supported dtype";
  return CUSPARSE_STATUS_EXECUTION_FAILED;
}

template <>
cusparseStatus_t Xcsrmm2<float>(cusparseHandle_t handle, cusparseOperation_t transA,
    cusparseOperation_t transB, int m, int n, int k, int nnz,
    const float* alpha, const cusparseMatDescr_t descrA,
    const float* csrValA, const int* csrRowPtrA, const int* csrColIndA,
    const float* B, int ldb, const float* beta, float* C, int ldc) {
  return cusparseScsrmm2(handle, transA, transB, m, n, k, nnz,
      alpha, descrA, csrValA, csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
}

template <>
cusparseStatus_t Xcsrmm2<double>(cusparseHandle_t handle, cusparseOperation_t transA,
    cusparseOperation_t transB, int m, int n, int k, int nnz,
    const double* alpha, const cusparseMatDescr_t descrA,
    const double* csrValA, const int* csrRowPtrA, const int* csrColIndA,
    const double* B, int ldb, const double* beta, double* C, int ldc) {
  return cusparseDcsrmm2(handle, transA, transB, m, n, k, nnz,
      alpha, descrA, csrValA, csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
}
#endif

template <typename DType>
cublasStatus_t Xgeam(cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n,
    const DType* alpha, const DType* A, int lda,
    const DType* beta, const DType* B, int ldb,
    DType* C, int ldc) {
  LOG(INFO) << "Not supported dtype";
  return CUBLAS_STATUS_EXECUTION_FAILED;
}

template <>
cublasStatus_t Xgeam<float>(cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n,
    const float* alpha, const float* A, int lda,
    const float* beta, const float* B, int ldb,
    float* C, int ldc) {
  return cublasSgeam(handle, transa, transb, m, n, alpha, A, lda,
      beta, B, ldb, C, ldc);
}

template <>
cublasStatus_t Xgeam<double>(cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n,
    const double* alpha, const double* A, int lda,
    const double* beta, const double* B, int ldb,
    double* C, int ldc) {
  return cublasDgeam(handle, transa, transb, m, n, alpha, A, lda,
      beta, B, ldb, C, ldc);
}
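/*!
 * \note A minimal sketch (not part of this file) of how Xgeam is used below:
 *       with alpha = 1 and beta = 0, geam degenerates to an out-of-place
 *       transpose, turning a column-major m x n matrix into its row-major
 *       layout. The handle and the device pointers d_src/d_dst here are
 *       hypothetical.
 * \code
 * const float alpha = 1., beta = 0.;
 * // d_src: column-major m x n (lda = m). d_dst: column-major n x m
 * // (ldc = n), which is exactly the row-major m x n layout of the source.
 * // B is ignored since beta == 0; CusparseCsrmm2 below passes nullptr.
 * CUBLAS_CALL(Xgeam<float>(handle, CUBLAS_OP_T, CUBLAS_OP_N, n, m,
 *     &alpha, d_src, m, &beta, nullptr, n, d_dst, n));
 * \endcode
 */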
/*! Cusparse implementation of SpMM on Csr format. */
template <typename DType>
void CusparseCsrmm2(
    const DLContext& ctx,
    const CSRMatrix& csr,
    const DType* B_data, const DType* A_data,
    DType* C_data,
    int x_length) {
  // We use csrmm2 to perform the following operation:
  //   C = A x B,
  // where A is a sparse matrix in csr format and B is the dense matrix of the
  // node feature tensor. However, since cusparse only supports column-major
  // while our tensors are stored in row-major, the actual computation is
  //   C = trans(A x trans(B)).
  // Currently, we use cublasXgeam to implement the transposition and allocate
  // intermediate workspace memory for it.
  const int m = csr.num_rows;
  const int n = x_length;
  const int k = csr.num_cols;
  const int nnz = csr.indices->shape[0];
  const DType alpha = 1.0;
  const DType beta = 0.0;
  // device
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
  // allocate matrix for temporary transposed output
  DType* trans_out = static_cast<DType*>(
      device->AllocWorkspace(ctx, m * n * sizeof(DType)));
  // all-ones data array, used when A carries no explicit edge values
  DType* valptr = nullptr;
  if (!A_data) {
    valptr = static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType)));
    _Fill(valptr, nnz, static_cast<DType>(1.));
  }
#if CUDART_VERSION >= 11000
  cusparseSpMatDescr_t matA;
  cusparseDnMatDescr_t matB, matC;
  constexpr auto cuda_dtype = std::is_same<DType, float>::value ?
      CUDA_R_32F : CUDA_R_64F;
  CUSPARSE_CALL(cusparseCreateCsr(&matA,
      m, k, nnz,
      static_cast<int32_t*>(csr.indptr->data),
      static_cast<int32_t*>(csr.indices->data),
      const_cast<DType*>(valptr ? valptr : A_data),
      CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
      CUSPARSE_INDEX_BASE_ZERO, cuda_dtype));
  CUSPARSE_CALL(cusparseCreateDnMat(&matB,
      n, k, n,
      const_cast<DType*>(B_data), cuda_dtype, CUSPARSE_ORDER_COL));
  CUSPARSE_CALL(cusparseCreateDnMat(&matC,
      m, n, m,
      trans_out, cuda_dtype, CUSPARSE_ORDER_COL));
  auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
  auto transB = CUSPARSE_OPERATION_TRANSPOSE;
  size_t workspace_size;
  CUSPARSE_CALL(cusparseSpMM_bufferSize(
      thr_entry->cusparse_handle, transA, transB,
      &alpha, matA, matB, &beta, matC,
      cuda_dtype, CUSPARSE_CSRMM_ALG1,
      &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  CUSPARSE_CALL(cusparseSpMM(
      thr_entry->cusparse_handle, transA, transB,
      &alpha, matA, matB, &beta, matC,
      cuda_dtype, CUSPARSE_CSRMM_ALG1,
      workspace));
  device->FreeWorkspace(ctx, workspace);
  CUSPARSE_CALL(cusparseDestroySpMat(matA));
  CUSPARSE_CALL(cusparseDestroyDnMat(matB));
  CUSPARSE_CALL(cusparseDestroyDnMat(matC));
#else
  cusparseMatDescr_t descr;
  CUSPARSE_CALL(cusparseCreateMatDescr(&descr));
  CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
  CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));
  CUSPARSE_CALL(Xcsrmm2<DType>(
      thr_entry->cusparse_handle,
      CUSPARSE_OPERATION_NON_TRANSPOSE,
      CUSPARSE_OPERATION_TRANSPOSE,
      m, n, k, nnz, &alpha,
      descr, (valptr) ? valptr : A_data,
      static_cast<int32_t*>(csr.indptr->data),
      static_cast<int32_t*>(csr.indices->data),
      B_data, n, &beta, trans_out, m));
  CUSPARSE_CALL(cusparseDestroyMatDescr(descr));
#endif
  if (valptr)
    device->FreeWorkspace(ctx, valptr);
  // transpose the output matrix back to row-major
  if (!thr_entry->cublas_handle)
    CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
  CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, thr_entry->stream));
  CUBLAS_CALL(Xgeam<DType>(
      thr_entry->cublas_handle,
      CUBLAS_OP_T,
      CUBLAS_OP_N,
      n, m,
      &alpha, trans_out, m,
      &beta, nullptr, n,
      C_data, n));
  device->FreeWorkspace(ctx, trans_out);
}

}  // namespace cusparse
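/*!
 * \note For orientation, a hedged sketch of a direct CusparseCsrmm2 call;
 *       csr, ufeat, out and x_length are hypothetical values, and in this
 *       file the only caller is SpMMCsr below.
 * \code
 * // out = A x ufeat with unit edge weights: passing nullptr as A_data makes
 * // CusparseCsrmm2 allocate and fill a temporary all-ones value array.
 * cusparse::CusparseCsrmm2<float>(
 *     ufeat->ctx, csr,
 *     static_cast<float*>(ufeat->data),   // dense input, row-major
 *     nullptr,                            // edge values; nullptr => all ones
 *     static_cast<float*>(out->data),     // dense output, row-major
 *     x_length);                          // feature dimension per node
 * \endcode
 */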
#define SWITCH_OP(op, Op, ...)                                      \
  do {                                                              \
    if ((op) == "add") {                                            \
      typedef cuda::binary::Add<DType> Op;                          \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "sub") {                                     \
      typedef cuda::binary::Sub<DType> Op;                          \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "mul") {                                     \
      typedef cuda::binary::Mul<DType> Op;                          \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "div") {                                     \
      typedef cuda::binary::Div<DType> Op;                          \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "copy_lhs") {                                \
      typedef cuda::binary::CopyLhs<DType> Op;                      \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "copy_rhs") {                                \
      typedef cuda::binary::CopyRhs<DType> Op;                      \
      { __VA_ARGS__ }                                               \
    } else {                                                        \
      LOG(FATAL) << "Unsupported SpMM binary operator: " << op;     \
    }                                                               \
  } while (0)

/*!
 * \brief CUDA implementation of g-SpMM on Csr format.
 * \note Uses cusparse if the reduce operator is `sum` and there is
 *       no broadcast; uses dgl's own kernel in all other cases.
 */
template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
  if (reduce == "sum") {
    if (sizeof(IdType) == 4 && op == "copy_lhs") {
      // cusparse fast path: A x ufeat with implicit all-one edge values.
      int64_t x_length = 1;
      for (int i = 1; i < ufeat->ndim; ++i)
        x_length *= ufeat->shape[i];
      cusparse::CusparseCsrmm2<DType>(
          ufeat->ctx, csr,
          static_cast<DType*>(ufeat->data),
          nullptr,
          static_cast<DType*>(out->data),
          x_length);
    } else if (sizeof(IdType) == 4 && op == "mul" &&
               efeat.NumElements() == csr.indices->shape[0]) {
      // cusparse fast path: scalar edge weights used as the csr value array.
      int64_t x_length = 1;
      for (int i = 1; i < ufeat->ndim; ++i)
        x_length *= ufeat->shape[i];
      if (!IsNullArray(csr.data))
        efeat = IndexSelect(efeat, csr.data);
      cusparse::CusparseCsrmm2<DType>(
          ufeat->ctx, csr,
          static_cast<DType*>(ufeat->data),
          static_cast<DType*>(efeat->data),
          static_cast<DType*>(out->data),
          x_length);
    } else {
      SWITCH_OP(op, Op, {
        cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
            bcast, csr, ufeat, efeat, out, NullArray(), NullArray());
      });
    }
  } else if (reduce == "max") {
    SWITCH_OP(op, Op, {
      cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
          bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else if (reduce == "min") {
    SWITCH_OP(op, Op, {
      cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
          bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else {
    LOG(FATAL) << "Not implemented";
  }
}
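/*!
 * \note To make the dispatch concrete: for op == "mul" with reduce == "sum"
 *       and no cusparse fast path, the SWITCH_OP invocation in SpMMCsr above
 *       expands to roughly the following (a sketch; the do/while wrapper and
 *       the non-matching branches are elided):
 * \code
 * typedef cuda::binary::Mul<DType> Op;
 * {
 *   cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
 *       bcast, csr, ufeat, efeat, out, NullArray(), NullArray());
 * }
 * \endcode
 */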
/*!
 * \brief CUDA implementation of g-SpMM on Coo format.
 */
template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const COOMatrix& coo,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
  if (reduce == "sum") {
    SWITCH_OP(op, Op, {
      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> >(
          bcast, coo, ufeat, efeat, out, NullArray(), NullArray());
    });
  } else if (reduce == "max") {
    SWITCH_OP(op, Op, {
      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Max<IdType, DType, true> >(
          bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else if (reduce == "min") {
    SWITCH_OP(op, Op, {
      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Min<IdType, DType, true> >(
          bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else {
    LOG(FATAL) << "Not implemented";
  }
}

// Explicit instantiations for the (index type, data type) combinations used on GPU.
template void SpMMCsr<kDLGPU, int32_t, float>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int64_t, float>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int32_t, double>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int64_t, double>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

template void SpMMCoo<kDLGPU, int32_t, float>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int64_t, float>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int32_t, double>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int64_t, double>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

}  // namespace aten
}  // namespace dgl
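/*!
 * \note End-to-end orientation only: a hedged sketch of how the exported
 *       instantiations above might be invoked. The tensors (csr, u, out) and
 *       the BcastOff value bcast are hypothetical; in DGL these calls are
 *       made by the operator dispatch layer rather than written by hand.
 * \code
 * using namespace dgl;
 * // "copy_lhs" + "sum" with 32-bit indices takes the cusparse fast path in
 * // SpMMCsr; an op such as "mul" with reduce "max" falls back to DGL's own
 * // CUDA kernel and fills the two auxiliary argmax arrays.
 * aten::SpMMCsr<kDLGPU, int32_t, float>(
 *     "copy_lhs", "sum", bcast, csr, u, aten::NullArray(), out, {});
 * \endcode
 */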