/*!
 *  Copyright (c) 2019 by Contributors
 * \file kernel/cuda/binary_reduce_sum.cu
 * \brief CUDA kernels for binary reduce sum
 */
#include <dgl/runtime/device_api.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./binary_reduce_impl.cuh"
#include "./backward_binary_reduce_impl.cuh"
#include "../utils.h"
#include "../csr_interface.h"

using minigun::advance::RuntimeConfig;

namespace dgl {
namespace kernel {
namespace cuda {

// specialization for cusparse
#if CUDART_VERSION < 11000
template <typename DType>
cusparseStatus_t Xcsrmm2(cusparseHandle_t handle, cusparseOperation_t transA,
    cusparseOperation_t transB, int m, int n, int k, int nnz,
    const DType* alpha, const cusparseMatDescr_t descrA,
    const DType* csrValA, const int* csrRowPtrA, const int* csrColIndA,
    const DType* B, int ldb, const DType* beta, DType* C, int ldc) {
  LOG(INFO) << "Not supported dtype";
  return CUSPARSE_STATUS_EXECUTION_FAILED;
}

template <>
cusparseStatus_t Xcsrmm2<float>(cusparseHandle_t handle, cusparseOperation_t transA,
    cusparseOperation_t transB, int m, int n, int k, int nnz,
    const float* alpha, const cusparseMatDescr_t descrA,
    const float* csrValA, const int* csrRowPtrA, const int* csrColIndA,
    const float* B, int ldb, const float* beta, float* C, int ldc) {
  return cusparseScsrmm2(handle, transA, transB, m, n, k, nnz,
      alpha, descrA, csrValA, csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
}

template <>
cusparseStatus_t Xcsrmm2<double>(cusparseHandle_t handle, cusparseOperation_t transA,
    cusparseOperation_t transB, int m, int n, int k, int nnz,
    const double* alpha, const cusparseMatDescr_t descrA,
    const double* csrValA, const int* csrRowPtrA, const int* csrColIndA,
    const double* B, int ldb, const double* beta, double* C, int ldc) {
  return cusparseDcsrmm2(handle, transA, transB, m, n, k, nnz,
      alpha, descrA, csrValA, csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
}
#endif

template <typename DType>
cublasStatus_t Xgeam(cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n,
    const DType* alpha, const DType* A, int lda,
    const DType* beta, const DType* B, int ldb,
    DType* C, int ldc) {
  LOG(INFO) << "Not supported dtype";
  return CUBLAS_STATUS_EXECUTION_FAILED;
}

template <>
cublasStatus_t Xgeam<float>(cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n,
    const float* alpha, const float* A, int lda,
    const float* beta, const float* B, int ldb,
    float* C, int ldc) {
  return cublasSgeam(handle, transa, transb, m, n, alpha, A, lda,
      beta, B, ldb, C, ldc);
}

template <>
cublasStatus_t Xgeam<double>(cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n,
    const double* alpha, const double* A, int lda,
    const double* beta, const double* B, int ldb,
    double* C, int ldc) {
  return cublasDgeam(handle, transa, transb, m, n, alpha, A, lda,
      beta, B, ldb, C, ldc);
}

template <typename DType>
void CusparseCsrmm2(
    const RuntimeConfig& rtcfg,
    const aten::CSRMatrix& csr,
    const DType* B_data, DType* C_data,
    int x_length) {
  // We use csrmm2 to perform the following operation:
  //   C = A x B,
  // where A is a sparse matrix in csr format and B is the dense node feature
  // tensor. However, since cusparse only supports column-major while our
  // tensor is stored in row-major, the actual computation is:
  //   C = trans(A x trans(B)).
  // Currently, we use cublasXgeam to implement the transposition and allocate
  // intermediate workspace memory for it.
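  // Layout note: a row-major (r x c) buffer reinterpreted as column-major with
  // leading dimension c is exactly its (c x r) transpose. The row-major feature
  // matrix B (k x n) is therefore handed to cusparse with the "transpose B"
  // operation so the logical k x n operand is recovered, and the column-major
  // (m x n) product is written into a temporary buffer (trans_out) that the
  // final cublasXgeam call transposes back into the row-major C_data.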
  const int m = csr.num_rows;
  const int n = x_length;
  const int k = csr.num_cols;
  const int nnz = csr.indices->shape[0];
  const DType alpha = 1.0;
  const DType beta = 0.0;
  // device
  auto device = runtime::DeviceAPI::Get(rtcfg.ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, rtcfg.stream));
  // allocate matrix for temporary transposed output
  DType* trans_out = static_cast<DType*>(
      device->AllocWorkspace(rtcfg.ctx, m * n * sizeof(DType)));
  // all-one data array
  DType* valptr = static_cast<DType*>(
      device->AllocWorkspace(rtcfg.ctx, nnz * sizeof(DType)));
  utils::Fill<kDLGPU>(rtcfg.ctx, valptr, nnz, static_cast<DType>(1.));
#if CUDART_VERSION >= 11000
  auto ctx = rtcfg.ctx;
  cusparseSpMatDescr_t matA;
  cusparseDnMatDescr_t matB, matC;
  constexpr auto cuda_dtype = std::is_same<DType, float>::value ? CUDA_R_32F : CUDA_R_64F;
  CUSPARSE_CALL(cusparseCreateCsr(&matA,
      m, k, nnz,
      static_cast<int32_t*>(csr.indptr->data),
      static_cast<int32_t*>(csr.indices->data),
      const_cast<DType*>(valptr),
      CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
      CUSPARSE_INDEX_BASE_ZERO, cuda_dtype));
  CUSPARSE_CALL(cusparseCreateDnMat(&matB,
      n, k, n,
      const_cast<DType*>(B_data), cuda_dtype, CUSPARSE_ORDER_COL));
  CUSPARSE_CALL(cusparseCreateDnMat(&matC,
      m, n, m,
      trans_out, cuda_dtype, CUSPARSE_ORDER_COL));
  auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
  auto transB = CUSPARSE_OPERATION_TRANSPOSE;
  size_t workspace_size;
  CUSPARSE_CALL(cusparseSpMM_bufferSize(
      thr_entry->cusparse_handle, transA, transB,
      &alpha, matA, matB, &beta, matC,
      cuda_dtype, CUSPARSE_CSRMM_ALG1, &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  CUSPARSE_CALL(cusparseSpMM(
      thr_entry->cusparse_handle, transA, transB,
      &alpha, matA, matB, &beta, matC,
      cuda_dtype, CUSPARSE_CSRMM_ALG1, workspace));
  device->FreeWorkspace(ctx, workspace);
  CUSPARSE_CALL(cusparseDestroySpMat(matA));
  CUSPARSE_CALL(cusparseDestroyDnMat(matB));
  CUSPARSE_CALL(cusparseDestroyDnMat(matC));
#else
  cusparseMatDescr_t descr;
  CUSPARSE_CALL(cusparseCreateMatDescr(&descr));
  CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
  CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));
  CUSPARSE_CALL(Xcsrmm2(
      thr_entry->cusparse_handle,
      CUSPARSE_OPERATION_NON_TRANSPOSE,
      CUSPARSE_OPERATION_TRANSPOSE,
      m, n, k, nnz, &alpha, descr, valptr,
      static_cast<int32_t*>(csr.indptr->data),
      static_cast<int32_t*>(csr.indices->data),
      B_data, n, &beta, trans_out, m));
  CUSPARSE_CALL(cusparseDestroyMatDescr(descr));
#endif
  device->FreeWorkspace(rtcfg.ctx, valptr);
  // transpose the output matrix
  if (!thr_entry->cublas_handle) {
    CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
  }
  CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, rtcfg.stream));
  CUBLAS_CALL(Xgeam(
      thr_entry->cublas_handle,
      CUBLAS_OP_T, CUBLAS_OP_N,
      n, m,
      &alpha, trans_out, m,
      &beta, nullptr, n,
      C_data, n));
  device->FreeWorkspace(rtcfg.ctx, trans_out);
}

// forward
template <typename DType>
void FallbackCallBinaryReduce(
    const RuntimeConfig& rtcfg,
    const CSRWrapper& graph,
    GData<int32_t, DType>* gdata) {
  constexpr int XPU = kDLGPU;
  typedef int32_t Idx;
  typedef SelectSrc LeftSelector;
  typedef SelectNone RightSelector;
  typedef BinaryUseLhs<DType> BinaryOp;
  typedef ReduceSum<kDLGPU, DType> Reducer;
  typedef cuda::FunctorsTempl<Idx, DType, LeftSelector,
                              RightSelector, BinaryOp, Reducer> Functors;
  typedef cuda::BinaryReduce<Idx, DType, Functors> UDF;
  // csr
  auto outcsr = graph.GetOutCSRMatrix();
  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(outcsr.indptr, outcsr.indices);
  // If the user-given mapping is none and the target is edge data, we need to
  // replace the mapping by the edge ids in the csr graph so that the edge
  // data is correctly read/written.
  if (LeftSelector::target == binary_op::kEdge
      && gdata->lhs_mapping == nullptr) {
    gdata->lhs_mapping = static_cast<Idx*>(outcsr.data->data);
  }
  if (RightSelector::target == binary_op::kEdge
      && gdata->rhs_mapping == nullptr) {
    gdata->rhs_mapping = static_cast<Idx*>(outcsr.data->data);
  }
  if (OutSelector<Reducer>::Type::target == binary_op::kEdge
      && gdata->out_mapping == nullptr) {
    gdata->out_mapping = static_cast<Idx*>(outcsr.data->data);
  }
  // TODO(minjie): allocator
  minigun::advance::Advance<XPU, Idx, cuda::AdvanceConfig,
      GData<Idx, DType>, UDF>(
          rtcfg, csr, gdata, minigun::IntArray1D());
}

template <typename DType>
void FallbackCallBackwardBinaryReduce(
    const RuntimeConfig& rtcfg,
    const CSRWrapper& graph,
    BackwardGData<int32_t, DType>* gdata) {
  constexpr int XPU = kDLGPU;
  constexpr int Mode = binary_op::kGradLhs;
  typedef int32_t Idx;
  typedef SelectSrc LeftSelector;
  typedef SelectNone RightSelector;
  typedef BinaryUseLhs<DType> BinaryOp;
  typedef ReduceSum<kDLGPU, DType> Reducer;
  // For backward computation, we use the reverse csr and switch dst and src.
  // This benefits the most common src_op_edge or copy_src case, because the
  // gradients of src are now aggregated into the destination buffer, reducing
  // contention on the atomic adds.
  auto incsr = graph.GetInCSRMatrix();
  minigun::Csr<Idx> csr = utils::CreateCsr<Idx>(incsr.indptr, incsr.indices);
  typedef cuda::BackwardFunctorsTempl<Idx, DType,
          typename SwitchSrcDst<LeftSelector>::Type,
          typename SwitchSrcDst<RightSelector>::Type,
          BinaryOp, Reducer> Functors;
  typedef cuda::BackwardBinaryReduce<Mode, Idx, DType, Functors> UDF;
  // If the user-given mapping is none and the target is edge data, we need to
  // replace the mapping by the edge ids in the csr graph so that the edge
  // data is correctly read/written.
  if (LeftSelector::target == binary_op::kEdge
      && gdata->lhs_mapping == nullptr) {
    gdata->lhs_mapping = static_cast<Idx*>(incsr.data->data);
  }
  if (RightSelector::target == binary_op::kEdge
      && gdata->rhs_mapping == nullptr) {
    gdata->rhs_mapping = static_cast<Idx*>(incsr.data->data);
  }
  if (OutSelector<Reducer>::Type::target == binary_op::kEdge
      && gdata->out_mapping == nullptr) {
    gdata->out_mapping = static_cast<Idx*>(incsr.data->data);
  }
  // TODO(minjie): allocator
  minigun::advance::Advance<XPU, Idx, cuda::AdvanceConfig,
      BackwardGData<Idx, DType>, UDF>(
          rtcfg, csr, gdata, minigun::IntArray1D());
}

}  // namespace cuda

template <>
void CallBinaryReduce<kDLGPU, int32_t, float, SelectSrc, SelectNone,
                      BinaryUseLhs<float>, ReduceSum<kDLGPU, float>>(
    const RuntimeConfig& rtcfg,
    const CSRWrapper& graph,
    GData<int32_t, float>* gdata) {
  if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
    cuda::FallbackCallBinaryReduce(rtcfg, graph, gdata);
  } else {
    // cusparse uses the reverse csr for csrmm
    auto csr = graph.GetInCSRMatrix();
    cuda::CusparseCsrmm2(rtcfg, csr, gdata->lhs_data, gdata->out_data,
        gdata->x_length);
  }
}

template <>
void CallBinaryReduce<kDLGPU, int32_t, double, SelectSrc, SelectNone,
                      BinaryUseLhs<double>, ReduceSum<kDLGPU, double>>(
    const RuntimeConfig& rtcfg,
    const CSRWrapper& graph,
    GData<int32_t, double>* gdata) {
  if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
    cuda::FallbackCallBinaryReduce(rtcfg, graph, gdata);
  } else {
    // cusparse uses the reverse csr for csrmm
    auto csr = graph.GetInCSRMatrix();
    cuda::CusparseCsrmm2(rtcfg, csr, gdata->lhs_data, gdata->out_data,
        gdata->x_length);
  }
}

// backward
template <>
void CallBackwardBinaryReduce<kDLGPU, binary_op::kGradLhs, int32_t, float,
                              SelectSrc, SelectNone,
                              BinaryUseLhs<float>, ReduceSum<kDLGPU, float>>(
    const RuntimeConfig& rtcfg,
    const CSRWrapper& graph,
    BackwardGData<int32_t, float>* gdata) {
  if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
    cuda::FallbackCallBackwardBinaryReduce(rtcfg, graph, gdata);
  } else {
    auto csr = graph.GetOutCSRMatrix();
    cuda::CusparseCsrmm2(rtcfg, csr, gdata->grad_out_data, gdata->grad_lhs_data,
        gdata->x_length);
  }
}

template <>
void CallBackwardBinaryReduce<kDLGPU, binary_op::kGradLhs, int32_t, double,
                              SelectSrc, SelectNone,
                              BinaryUseLhs<double>, ReduceSum<kDLGPU, double>>(
    const RuntimeConfig& rtcfg,
    const CSRWrapper& graph,
    BackwardGData<int32_t, double>* gdata) {
  if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
    cuda::FallbackCallBackwardBinaryReduce(rtcfg, graph, gdata);
  } else {
    auto csr = graph.GetOutCSRMatrix();
    cuda::CusparseCsrmm2(rtcfg, csr, gdata->grad_out_data, gdata->grad_lhs_data,
        gdata->x_length);
  }
}

// generate definitions
#define REDUCER ReduceSum
#define XPU kDLGPU
#define IDX int32_t

EVAL(GEN_DTYPE, GEN_OP_TARGET, GEN_DEFINE);
EVAL(GEN_BACKWARD_MODE, GEN_DTYPE, GEN_OP_TARGET, GEN_BACKWARD_DEFINE);

}  // namespace kernel
}  // namespace dgl