/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cublasMMWrapper.h"
#include "cuda_utils.h"
#include "src/turbomind/macro.h"

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#endif

namespace turbomind {

cublasMMWrapper::cublasMMWrapper(cublasHandle_t   cublas_handle,
                                 cublasLtHandle_t cublaslt_handle,
                                 cudaStream_t     stream,
                                 cublasAlgoMap*   cublas_algo_map,
                                 std::mutex*      mu,
                                 IAllocator*      allocator):
    cublas_handle_(cublas_handle),
    cublaslt_handle_(cublaslt_handle),
    stream_(stream),
    cublas_algo_map_(cublas_algo_map),
    mu_(mu),
    allocator_(allocator)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    if (allocator_ != nullptr) {
        cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false);
    }

    // hgemm-switch 0:fp32r, 1:fp16r-fp32r, 2:fp16r ----xzhou 20240427
    m_ihgemm_switch = 0;
    const char* env_var_value_str = std::getenv("LMDEPLOY_HGEMM_SWITCH");
    if (env_var_value_str != nullptr) {
        m_ihgemm_switch = std::stoi(env_var_value_str);
    }
    m_ihgemm_switch_n = 16;
    const char* env_n_value_str = std::getenv("LMDEPLOY_HGEMM_SWITCH_N");
    if (env_n_value_str != nullptr) {
        m_ihgemm_switch_n = std::stoi(env_n_value_str);
    }
    if (m_ihgemm_switch != 0) {
        printf("hgemm_switch=%d, hgemm_switch_n_limit=%d\n", m_ihgemm_switch, m_ihgemm_switch_n);
    }
}

#ifdef SPARSITY_ENABLED
cublasMMWrapper::cublasMMWrapper(cublasHandle_t     cublas_handle,
                                 cublasLtHandle_t   cublaslt_handle,
                                 cusparseLtHandle_t cusparselt_handle,
                                 cudaStream_t       stream,
                                 cublasAlgoMap*     cublas_algo_map,
                                 std::mutex*        mu,
                                 IAllocator*        allocator):
    cublas_handle_(cublas_handle),
    cublaslt_handle_(cublaslt_handle),
    cusparselt_handle_(cusparselt_handle),
    stream_(stream),
    cublas_algo_map_(cublas_algo_map),
    mu_(mu),
    allocator_(allocator)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    if (allocator_ != nullptr) {
        cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false);
    }
}
#endif

cublasMMWrapper::~cublasMMWrapper()
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    mu_ = nullptr;
    if (allocator_ != nullptr) {
        allocator_->free((void**)(&cublas_workspace_));
        allocator_ = nullptr;
    }
}

cublasMMWrapper::cublasMMWrapper(const cublasMMWrapper& wrapper):
    cublas_handle_(wrapper.cublas_handle_),
    cublaslt_handle_(wrapper.cublaslt_handle_),
#ifdef SPARSITY_ENABLED
    cusparselt_handle_(wrapper.cusparselt_handle_),
#endif
    stream_(wrapper.stream_),
    cublas_algo_map_(wrapper.cublas_algo_map_),
    mu_(wrapper.mu_),
    allocator_(wrapper.allocator_)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    if (allocator_ != nullptr) {
        cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false);
    }
}

void cublasMMWrapper::Gemm(cublasOperation_t transa,
                           cublasOperation_t transb,
                           const int         m,
                           const int         n,
                           const int         k,
                           const void*       alpha,
                           const void*       A,
                           cudaDataType_t    Atype,
                           int               lda,
                           const void*       B,
                           cudaDataType_t    Btype,
                           int               ldb,
                           const void*       beta,
                           void*             C,
                           cudaDataType_t    Ctype,
                           int               ldc,
                           cudaDataType_t    computeType,
                           cublasGemmAlgo_t  algo)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    mu_->lock();
    // hgemm-switch ----xzhou 20240427
    if (m_ihgemm_switch == 1 && (m == 5120 || m == 4096 || m == 12288 || m == 11008)
        && n <= m_ihgemm_switch_n && Atype == CUDA_R_16F) {
        computeType = CUDA_R_16F;
    }
    check_cuda_error(cublasGemmEx(cublas_handle_, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
                                  beta, C, Ctype, ldc, computeType, algo));
    sync_check_cuda_error();
    mu_->unlock();
}

void cublasMMWrapper::Gemm(cublasOperation_t transa,
                           cublasOperation_t transb,
                           const int         m,
                           const int         n,
                           const int         k,
                           const void*       A,
                           const int         lda,
                           const void*       B,
                           const int         ldb,
                           void*             C,
                           const int         ldc)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f);
}

void cublasMMWrapper::Gemm(cublasOperation_t transa,
                           cublasOperation_t transb,
                           const int         m,
                           const int         n,
                           const int         k,
                           const void*       A,
                           const int         lda,
                           const void*       B,
                           const int         ldb,
                           void*             C,
                           const int         ldc,
                           float             f_alpha,
                           float             f_beta)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);

    half h_alpha = (half)(f_alpha);
    half h_beta  = (half)(f_beta);

    mu_->lock();
    // TODO: default cublas libs
    cudaDataType_t computeType = computeType_;
    // hgemm-switch ------xzhou 20240427
    if (m_ihgemm_switch == 1 && (m == 5120 || m == 4096 || m == 12288 || m == 11008)
        && n <= m_ihgemm_switch_n && Atype_ == CUDA_R_16F) {
        computeType = CUDA_R_16F;
    }
    int  is_fp16_computeType = computeType == CUDA_R_16F ? 1 : 0;
    bool using_cublasLt      = (Atype_ == CUDA_R_16F) ? true : false;
    int  batch_count         = 1;
    // fp32 use cublas as default
    // fp16 use cublasLt as default
    const void* alpha = is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
    const void* beta  = is_fp16_computeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);

    int findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, getCublasDataType(Atype_));

    cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
    if (findAlgo) {
        if (info.stages != -1) {
            using_cublasLt = false;
        }
        else {
            using_cublasLt = false;
        }
    }

    // if (using_cublasLt) {
    // if (0) {
    //     cublasLtMatmulDesc_t   operationDesc = NULL;
    //     cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
    //     cudaDataType_t         scaleType;
    // #if (CUDART_VERSION >= 11000)
    //     cublasComputeType_t computeType;
    // #else
    //     cudaDataType_t computeType;
    // #endif
    //     if (is_fp16_computeType) {
    // #if (CUDART_VERSION >= 11000)
    //         computeType = CUBLAS_COMPUTE_16F;
    // #else
    //         computeType = CUDA_R_16F;
    // #endif
    //         scaleType = CUDA_R_16F;
    //     }
    //     else {
    // #if (CUDART_VERSION >= 11000)
    //         computeType = CUBLAS_COMPUTE_32F;
    // #else
    //         computeType = CUDA_R_32F;
    // #endif
    //         scaleType = CUDA_R_32F;
    //     }
    //     // --------------------------------------
    //     // Create descriptors for the original matrices
    //     cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
    //     cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
    //     cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc);
    // #if (CUDART_VERSION >= 11000)
    //     cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
    // #else
    //     cublasLtMatmulDescCreate(&operationDesc, computeType);
    // #endif
    //     cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
    //     cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
    //     cublasLtMatmulAlgo_t algo;
    //     void* workSpace     = cublas_workspace_;
    //     int   workspaceSize = cublas_workspace_ == NULL ?
    //                           0 : CUBLAS_WORKSPACE_SIZE;
    //     if (findAlgo) {
    //         if (info.workspaceSize > workspaceSize) {
    //             findAlgo = 0;
    //         }
    //         else {
    //             cublasLtMatmulAlgoInit(
    //                 cublaslt_handle_, computeType, scaleType, Atype_, Btype_, Ctype_, Ctype_, info.algoId, &algo);
    //             cublasLtMatmulAlgoConfigSetAttribute(
    //                 &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
    //             cublasLtMatmulAlgoConfigSetAttribute(
    //                 &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
    //             cublasLtMatmulAlgoConfigSetAttribute(
    //                 &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
    //             cublasLtMatmulAlgoConfigSetAttribute(
    //                 &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
    //             cublasLtMatmulAlgoConfigSetAttribute(&algo,
    //                                                  CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
    //                                                  &(info.reductionScheme),
    //                                                  sizeof(info.reductionScheme));
    //             // #if (CUDART_VERSION >= 11000)
    //             // cublasLtMatmulAlgoConfigSetAttribute(
    //             //     &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
    //             // #endif
    // #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
    //             cublasLtMatmulAlgoConfigSetAttribute(
    //                 &algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), sizeof(info.inner_shapeId));
    //             cublasLtMatmulAlgoConfigSetAttribute(&algo,
    //                                                  CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID,
    //                                                  &(info.cluster_shapeId),
    //                                                  sizeof(info.cluster_shapeId));
    // #elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
    //             cublasLtMatmulAlgoConfigSetAttribute(
    //                 &algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), sizeof(info.mma_shapeId));
    //             cublasLtMatmulAlgoConfigSetAttribute(
    //                 &algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), sizeof(info.cga_shapeId));
    //             cublasLtMatmulAlgoConfigSetAttribute(
    //                 &algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), sizeof(info.sche_mode));
    // #endif
    //         }
    //     }
    //     // cublasLtMatmul(cublaslt_handle_,
    //     //                operationDesc,
    //     //                alpha,
    //     //                A,
    //     //                Adesc,
    //     //                B,
    //     //                Bdesc,
    //     //                beta,
    //     //                C,
    //     //                Cdesc,
    //     //                C,
    //     //                Cdesc,
    //     //                (findAlgo == 1 ?
    //     //                   (&algo) : NULL),
    //     //                workSpace,
    //     //                workspaceSize,
    //     //                stream_);
    //     cublasLtMatmulDescDestroy(operationDesc);
    //     cublasLtMatrixLayoutDestroy(Adesc);
    //     cublasLtMatrixLayoutDestroy(Bdesc);
    //     cublasLtMatrixLayoutDestroy(Cdesc);
    //     sync_check_cuda_error();
    // }
    // else {
    int cublasAlgo = info.algoId;
    check_cuda_error(cublasGemmEx(cublas_handle_, transa, transb, m, n, k, alpha, A, Atype_, lda, B, Btype_, ldb,
                                  beta, C, Ctype_, ldc, computeType, static_cast<cublasGemmAlgo_t>(cublasAlgo)));
    sync_check_cuda_error();
    // }
    mu_->unlock();
}

void cublasMMWrapper::setFP32GemmConfig()
{
    Atype_       = CUDA_R_32F;
    Btype_       = CUDA_R_32F;
    Ctype_       = CUDA_R_32F;
    computeType_ = CUDA_R_32F;
}

void cublasMMWrapper::setFP16GemmConfig()
{
    Atype_       = CUDA_R_16F;
    Btype_       = CUDA_R_16F;
    Ctype_       = CUDA_R_16F;
    computeType_ = CUDA_R_32F;
}

#ifdef ENABLE_BF16
void cublasMMWrapper::setBF16GemmConfig()
{
    Atype_       = CUDA_R_16BF;
    Btype_       = CUDA_R_16BF;
    Ctype_       = CUDA_R_16BF;
    computeType_ = CUDA_R_32F;
}
#endif

void cublasMMWrapper::setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType)
{
    Atype_       = aType;
    Btype_       = bType;
    Ctype_       = cType;
    computeType_ = computeType;
}

CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
{
    if (data_type == CUDA_R_16F) {
        return HALF_DATATYPE;
    }
    else if (data_type == CUDA_R_32F) {
        return FLOAT_DATATYPE;
    }
#ifdef ENABLE_BF16
    else if (data_type == CUDA_R_16BF) {
        return BFLOAT16_DATATYPE;
    }
#endif
    return FLOAT_DATATYPE;
}

// #if (CUDART_VERSION >= 11000)
// // input, weight, output are row-major
// // only works for cublas 11.x
// void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
//                            const int k, const void* A, const int lda, const void* B, const int ldb,
//                            const void* bias, void* C, const int ldc)
// {
//     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
//     cudaDataType_t      Atype, Btype, Ctype;
//     cublasComputeType_t computeType;
//     cudaDataType_t      scaleType;
//     float               alpha_float = 1.0f;
//     float               beta_float  = 0.0f;
//     half                alpha_half  = half(1.0f);
//     half                beta_half   = half(0.0f);
//     void *              alpha, *beta;
//     // int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
//     if (Atype_ == CUDA_R_32F) {
//         computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
//         Atype       = CUDA_R_32F;
//         Btype       = CUDA_R_32F;
//         Ctype       = CUDA_R_32F;
//         scaleType   = CUDA_R_32F;
//         alpha       = &alpha_float;
//         beta        = &beta_float;
//     }
//     else if (Atype_ == CUDA_R_16BF) {
//         computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
//         Atype       = CUDA_R_16BF;
//         Btype       = CUDA_R_16BF;
//         Ctype       = CUDA_R_16BF;
//         scaleType   = CUDA_R_32F;
//         alpha       = &alpha_float;
//         beta        = &beta_float;
//     }
//     else {
//         computeType = CUBLAS_COMPUTE_16F;
//         Atype       = CUDA_R_16F;
//         Btype       = CUDA_R_16F;
//         Ctype       = CUDA_R_16F;
//         scaleType   = CUDA_R_16F;
//         alpha       = &alpha_half;
//         beta        = &beta_half;
//     }
//     cublasLtMatmulDesc_t   operationDesc = NULL;
//     cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
//     cublasLtEpilogue_t     epi = CUBLASLT_EPILOGUE_BIAS;
//     cublasLtMatrixLayoutCreate(&Adesc, Atype, (transa == CUBLAS_OP_N) ? m : k, (transa == CUBLAS_OP_N) ? k : m, lda);
//     cublasLtMatrixLayoutCreate(&Bdesc, Btype, (transb == CUBLAS_OP_N) ? k : n, (transb == CUBLAS_OP_N) ?
//         n : k, ldb);
//     cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc);
//     cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType);
//     cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t));
//     cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
//     cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
//     cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*));
//     // check_cuda_error(cublasLtMatmul(
//     //     cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, NULL, NULL, 0, stream_));
//     cublasLtMatrixLayoutDestroy(Adesc);
//     cublasLtMatrixLayoutDestroy(Bdesc);
//     cublasLtMatrixLayoutDestroy(Cdesc);
//     cublasLtMatmulDescDestroy(operationDesc);
// }
// #endif

void cublasMMWrapper::setStream(cudaStream_t stream)
{
    stream_ = stream;
}

void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa,
                                         cublasOperation_t transb,
                                         const int         m,
                                         const int         n,
                                         const int         k,
                                         const void*       A,
                                         const int         lda,
                                         const int64_t     strideA,
                                         const void*       B,
                                         const int         ldb,
                                         const int64_t     strideB,
                                         void*             C,
                                         const int         ldc,
                                         const int64_t     strideC,
                                         const int         batch_count,
                                         const float       f_alpha,
                                         const float       f_beta)
{
    half h_alpha = (half)f_alpha;
    half h_beta  = (half)f_beta;

    mu_->lock();
    int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;

    const void* alpha = is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
    const void* beta  = is_fp16_computeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);

    cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));

    check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle_, transa, transb, m, n, k, alpha,
                                                A, Atype_, lda, strideA, B, Btype_, ldb, strideB,
                                                beta, C, Ctype_, ldc, strideC, batch_count, computeType_,
                                                static_cast<cublasGemmAlgo_t>(info.algoId)));
    mu_->unlock();
}

void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa,
                                         cublasOperation_t transb,
                                         const int         m,
                                         const int         n,
                                         const int         k,
                                         const float       f_alpha,
                                         const void*       A,
                                         cudaDataType_t    AType,
                                         const int         lda,
                                         const int64_t     strideA,
                                         const void*       B,
                                         cudaDataType_t    BType,
                                         const int         ldb,
                                         const int64_t     strideB,
                                         const float       f_beta,
                                         void*             C,
                                         cudaDataType_t    CType,
                                         const int         ldc,
                                         const int64_t     strideC,
                                         const int         batch_count,
                                         cudaDataType_t    computeType)
{
    half h_alpha = (half)f_alpha;
    half h_beta  = (half)f_beta;

    mu_->lock();
    int is_fp16_computeType = computeType == CUDA_R_16F ? 1 : 0;

    const void* alpha = is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
    const void* beta  = is_fp16_computeType ?
        reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);

    cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));

    check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle_, transa, transb, m, n, k, alpha,
                                                A, AType, lda, strideA, B, BType, ldb, strideB,
                                                beta, C, CType, ldc, strideC, batch_count, computeType,
                                                static_cast<cublasGemmAlgo_t>(info.algoId)));
    mu_->unlock();
}

void cublasMMWrapper::batchedGemm(cublasOperation_t  transa,
                                  cublasOperation_t  transb,
                                  const int          m,
                                  const int          n,
                                  const int          k,
                                  const void* const* A,
                                  const int          lda,
                                  const void* const* B,
                                  const int          ldb,
                                  void* const*       C,
                                  const int          ldc,
                                  const int          batch_count)
{
    float f_alpha = static_cast<float>(1.0f);
    float f_beta  = static_cast<float>(0.0f);

    half h_alpha = (half)1.0f;
    half h_beta  = (half)0.0f;

    mu_->lock();
    int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;

    const void* alpha = is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
    const void* beta  = is_fp16_computeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);

    cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));

    check_cuda_error(cublasGemmBatchedEx(cublas_handle_, transa, transb, m, n, k, alpha, A, Atype_, lda,
                                         B, Btype_, ldb, beta, C, Ctype_, ldc, batch_count, computeType_,
                                         static_cast<cublasGemmAlgo_t>(info.algoId)));
    mu_->unlock();
}

bool cublasMMWrapper::isFuseBatchGemm(const int batch_count, const int m, const int k, const int n)
{
    CublasDataType data_type = getCublasDataType(Atype_);

    if (cublas_algo_map_->isExist(batch_count, m, k, n, data_type) == false
        || cublas_algo_map_->isExist(1, m, k, n, data_type) == false) {
        return false;
    }
    else {
        return cublas_algo_map_->getAlgo(batch_count, m, k, n, data_type).exec_time
               < 3 * cublas_algo_map_->getAlgo(1, m, k, n, data_type).exec_time;
    }
}

#ifdef SPARSITY_ENABLED
void cublasMMWrapper::SpGemm(cublasOperation_t transa,
                             cublasOperation_t transb,
                             const int         m,
                             const int         n,
                             const int         k,
                             const void*       A,
                             const void*       B,
                             void*             C)
{
    if (Atype_ != CUDA_R_16F || Btype_ != CUDA_R_16F || Ctype_ != CUDA_R_16F) {
        throw std::runtime_error("\n[TM][ERROR] sparse GEMM only supports FP16 data type now.");
    }
    static bool not_printed_fp32_accumulation_warning = true;
    if (computeType_ != CUDA_R_16F && not_printed_fp32_accumulation_warning) {
        printf("[TM][WARNING] cublasMMWrapper sets to FP32 compute type, "
               "but sparse gemm will use FP16 compute type since cusparselt "
               "supports FP16 accumulation only.\n");
        not_printed_fp32_accumulation_warning = false;
    }
    cusparseOrder_t     order = CUSPARSE_ORDER_COL;
    cusparseOperation_t opA   = (transa == CUBLAS_OP_N) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE;
    cusparseOperation_t opB   = (transb == CUBLAS_OP_N) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE;
    cusparseComputeType compute_type = CUSPARSE_COMPUTE_16F;

    cusparseLtMatmulDescriptor_t   matmul;
    cusparseLtMatmulAlgSelection_t alg_sel;
    cusparseLtMatmulPlan_t         plan;

    bool     is_rowmajor    = (order == CUSPARSE_ORDER_ROW);
    bool     isA_transposed = (opA != CUSPARSE_OPERATION_NON_TRANSPOSE);
    bool     isB_transposed = (opB != CUSPARSE_OPERATION_NON_TRANSPOSE);
    auto     num_A_rows     = (isA_transposed) ? k : m;
    auto     num_A_cols     = (isA_transposed) ? m : k;
    auto     num_B_rows     = (isB_transposed) ? n : k;
    auto     num_B_cols     = (isB_transposed) ? k : n;
    auto     num_C_rows     = m;
    auto     num_C_cols     = n;
    unsigned alignment      = 16;
    auto     lda            = (is_rowmajor) ? num_A_cols : num_A_rows;
    auto     ldb            = (is_rowmajor) ? num_B_cols : num_B_rows;
    auto     ldc            = (is_rowmajor) ?
        num_C_cols : num_C_rows;

    float _alpha(1.0f);
    float _beta(0.0f);

    char mark[256];
    sprintf(mark, "%d_%d_%d_%d", 1, m, n, k);
    if (sp_mat_A_desc_map_.find(mark) != sp_mat_A_desc_map_.end()) {
        CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(&cusparselt_handle_,
                                                      &matmul,
                                                      opA,
                                                      opB,
                                                      &sp_mat_A_desc_map_[mark],
                                                      &sp_mat_B_desc_map_[mark],
                                                      &sp_mat_C_desc_map_[mark],
                                                      &sp_mat_C_desc_map_[mark],
                                                      compute_type))
    }
    else {
        // initializing MatDesc takes a lot of time
        cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
        sp_mat_A_desc_map_[mark] = mat_A;
        sp_mat_B_desc_map_[mark] = mat_B;
        sp_mat_C_desc_map_[mark] = mat_C;
        CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(&cusparselt_handle_,
                                                          &sp_mat_A_desc_map_[mark],
                                                          num_A_rows,
                                                          num_A_cols,
                                                          lda,
                                                          alignment,
                                                          Atype_,
                                                          order,
                                                          CUSPARSELT_SPARSITY_50_PERCENT))
        CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(
            &cusparselt_handle_, &sp_mat_B_desc_map_[mark], num_B_rows, num_B_cols, ldb, alignment, Btype_, order))
        CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(
            &cusparselt_handle_, &sp_mat_C_desc_map_[mark], num_C_rows, num_C_cols, ldc, alignment, Ctype_, order))
        CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(&cusparselt_handle_,
                                                      &matmul,
                                                      opA,
                                                      opB,
                                                      &sp_mat_A_desc_map_[mark],
                                                      &sp_mat_B_desc_map_[mark],
                                                      &sp_mat_C_desc_map_[mark],
                                                      &sp_mat_C_desc_map_[mark],
                                                      compute_type))
    }
    mu_->lock();
    CHECK_CUSPARSE(
        cusparseLtMatmulAlgSelectionInit(&cusparselt_handle_, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
    int alg = cublas_algo_map_->getSpAlgo(1, num_A_rows, num_B_cols, num_A_cols);
    CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
        &cusparselt_handle_, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg, sizeof(alg)))
    size_t workspace_size;
    CHECK_CUSPARSE(cusparseLtMatmulGetWorkspace(&cusparselt_handle_, &alg_sel, &workspace_size))
    CHECK_CUSPARSE(cusparseLtMatmulPlanInit(&cusparselt_handle_, &plan, &matmul, &alg_sel, workspace_size))

    void*        d_workspace = nullptr;
    int          num_streams = 1;
    cudaStream_t streams[1]  = {stream_};
    CHECK_CUSPARSE(
        cusparseLtMatmul(&cusparselt_handle_, &plan, &_alpha, A, B, &_beta, C, C, d_workspace, streams, num_streams))
    CHECK_CUSPARSE(cusparseLtMatmulPlanDestroy(&plan))
    sync_check_cuda_error();
    mu_->unlock();
}

size_t cublasMMWrapper::getSparseMatrixSize(int m, int k)
{
    // Get a compressed matrix size of shape (m, k) used in cusparselt.
    auto            Atype_    = CUDA_R_16F;
    cusparseOrder_t order     = CUSPARSE_ORDER_COL;
    unsigned        alignment = 16;
    int             num_A_rows = m;
    int             num_A_cols = k;
    int             lda        = num_A_rows;

    cusparseLtMatDescriptor_t mat_A;
    CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(&cusparselt_handle_,
                                                      &mat_A,
                                                      num_A_rows,
                                                      num_A_cols,
                                                      lda,
                                                      alignment,
                                                      Atype_,
                                                      order,
                                                      CUSPARSELT_SPARSITY_50_PERCENT));
    size_t compressed_size = 0;
    CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&cusparselt_handle_, &mat_A, &compressed_size));
    return compressed_size;
}

void cublasMMWrapper::compressMatrix(const void* input, void* output, const int m, const int k)
{
    cusparseOrder_t           order = CUSPARSE_ORDER_COL;
    cusparseOperation_t       opA   = CUSPARSE_OPERATION_NON_TRANSPOSE;
    cusparseLtMatDescriptor_t mat_A;
    unsigned                  alignment = 16;
    CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
        &cusparselt_handle_, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
    CHECK_CUSPARSE(cusparseLtSpMMACompress2(&cusparselt_handle_, &mat_A, true, opA, input, output, stream_))
    sync_check_cuda_error();
}

bool cublasMMWrapper::isUseSparse(const int batch_count, const int m, const int n, const int k)
{
    return cublas_algo_map_->isUseSparse(batch_count, m, n, k);
}
#endif

std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHandle_t       lightHandle,
                                                                    cublasLtMatmulDesc_t   computeDesc,
                                                                    const void*            alpha,
                                                                    const void*            A,
                                                                    cublasLtMatrixLayout_t Adesc,
                                                                    const void*            B,
                                                                    cublasLtMatrixLayout_t Bdesc,
                                                                    const void*            beta,
                                                                    const void*            C,
                                                                    cublasLtMatrixLayout_t Cdesc,
                                                                    void*                  D,
                                                                    cublasLtMatrixLayout_t Ddesc,
                                                                    cudaStream_t           stream)
{
    // #if (CUBLAS_VERSION) <= 11601
    FT_CHECK_WITH_INFO(false, "CUBLAS version too low.");
    return {false, cublasLtMatmulAlgo_t{}};
    /* #else
    size_t  returnSize;
    int32_t pointer_mode;
    cublasLtMatmulDescGetAttribute(
        computeDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode), &returnSize);

    std::vector<cublasLtMatmulHeuristicResult_t> heuristics(200);
    cublasLtMatmulPreference_t                   preference;
    check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
    check_cuda_error(cublasLtMatmulPreferenceInit(preference));
    uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE;
    check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
        preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));
#if (CUBLAS_VERSION) <= 12000
    uint32_t pointer_mode_mask = 0;
    check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
        preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, sizeof(pointer_mode_mask)));
#endif

    int  return_count = 0;
    auto ret          = cublasLtMatmulAlgoGetHeuristic(lightHandle,
                                                       computeDesc,
                                                       Adesc,
                                                       Bdesc,
                                                       Cdesc,
                                                       Ddesc,
                                                       preference,
                                                       heuristics.size(),
                                                       heuristics.data(),
                                                       &return_count);
    heuristics.resize(return_count);

    std::map<int32_t, std::vector<float>> algo_results;
    for (const auto& heuristic : heuristics) {
        cublasLtMatmulAlgo_t algo = heuristic.algo;
        int32_t              algo_id;
        cublasLtMatmulAlgoConfigGetAttribute(&algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize);

        cudaEvent_t start_event, stop_event;
        cudaEventCreate(&start_event);
        cudaEventCreate(&stop_event);

        float my_alpha = 1.0f;
        float my_beta  = 0.0f;

        for (int i = 0; i < 11; i++) {
            float duration_ms;
            cudaEventRecord(start_event, stream);
            check_cuda_error(cublasLtMatmul(lightHandle,
                                            computeDesc,
                                            alpha,
                                            A,
                                            Adesc,
                                            B,
                                            Bdesc,
                                            beta,
                                            C,
                                            Cdesc,
                                            D,
                                            Ddesc,
                                            &algo,
                                            cublas_workspace_,
                                            CUBLAS_WORKSPACE_SIZE,
                                            stream));
            cudaEventRecord(stop_event, stream);
            cudaEventSynchronize(stop_event);
            cudaEventElapsedTime(&duration_ms, start_event, stop_event);

            algo_results[algo_id].push_back(duration_ms);
        }
        std::sort(algo_results[algo_id].begin(),
                  algo_results[algo_id].end());
    }

    cublasLtMatmulHeuristicResult_t result;
    float                           best_time = INFINITY;
    for (const auto& heuristic : heuristics) {
        cublasLtMatmulAlgo_t algo = heuristic.algo;
        int32_t              algo_id;
        cublasLtMatmulAlgoConfigGetAttribute(&algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize);
        const auto& results = algo_results[algo_id];

        if (results.size() > 0 && results[5] < best_time) {
            best_time = results[5];
            result    = heuristic;
        }
    }

    return {best_time != INFINITY, result.algo};
#endif */
}

cublasMMWrapper::MatrixLayout cublasMMWrapper::createMatrixLayout(cublasLtMatrixLayout_t Mdesc)
{
    size_t       returnSize;
    MatrixLayout m_layout;

    FT_CHECK_WITH_INFO(false, "cublasLtMatrixLayoutGetAttribute is not supported.");
    /*
    cublasLtMatrixLayoutGetAttribute(
        Mdesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &std::get<0>(m_layout), sizeof(std::get<0>(m_layout)), &returnSize);
    cublasLtMatrixLayoutGetAttribute(
        Mdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &std::get<1>(m_layout), sizeof(std::get<1>(m_layout)), &returnSize);
    cublasLtMatrixLayoutGetAttribute(
        Mdesc, CUBLASLT_MATRIX_LAYOUT_ROWS, &std::get<2>(m_layout), sizeof(std::get<2>(m_layout)), &returnSize);
    cublasLtMatrixLayoutGetAttribute(
        Mdesc, CUBLASLT_MATRIX_LAYOUT_COLS, &std::get<3>(m_layout), sizeof(std::get<3>(m_layout)), &returnSize);
    */
    return m_layout;
}

cublasStatus_t cublasMMWrapper::cublasLtMatmulWrapper(cublasLtHandle_t            lightHandle,
                                                      cublasLtMatmulDesc_t        computeDesc,
                                                      const void*                 alpha,
                                                      const void*                 A,
                                                      cublasLtMatrixLayout_t      Adesc,
                                                      const void*                 B,
                                                      cublasLtMatrixLayout_t      Bdesc,
                                                      const void*                 beta,
                                                      const void*                 C,
                                                      cublasLtMatrixLayout_t      Cdesc,
                                                      void*                       D,
                                                      cublasLtMatrixLayout_t      Ddesc,
                                                      const cublasLtMatmulAlgo_t* algo,
                                                      void*                       workspace,
                                                      size_t                      workspaceSizeInBytes,
                                                      cudaStream_t                stream)
{
    cache_idx_t cache_idx{
        computeDesc,
        {createMatrixLayout(Adesc), createMatrixLayout(Bdesc), createMatrixLayout(Cdesc), createMatrixLayout(Ddesc)}};

    cublasLtMatmulAlgo_t algo_value;
    bool                 found_algo = false;
    if (algo == nullptr) {
        if (algo_cache.find(cache_idx) == algo_cache.end()) {
            auto result =
                findBestAlgo(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc, stream);
            if (result.first) {
                algo_cache[cache_idx] = result.second;
                algo_value            = result.second;
                found_algo            = true;
            }
        }
        else {
            algo_value = algo_cache[cache_idx];
            found_algo = true;
        }
    }

    return cublasLtMatmul(lightHandle,
                          computeDesc,
                          alpha,
                          A,
                          Adesc,
                          B,
                          Bdesc,
                          beta,
                          C,
                          Cdesc,
                          D,
                          Ddesc,
                          found_algo ? &algo_value : algo,
                          workspace,
                          workspaceSizeInBytes,
                          stream);
}

void cublasMMWrapper::_Int8Gemm(const int     m,
                                const int     n,
                                const int     k,
                                const int8_t* A,
                                const int     lda,
                                const int8_t* B,
                                const int     ldb,
                                void*         C,
                                const int     ldc,
                                const void*   alpha,
                                const int     mode,
                                const bool    per_column_scaling)
{
    /* mode:
     *  - 0: int8 * int8 -> int32 -> int8
     *  - 1: int8 * int8 -> int32 -> int32
     */
    // #if (CUBLAS_VERSION) <= 11601
#if 1
    FT_CHECK_WITH_INFO(false, "CUBLAS version too low.");
#else
    mu_->lock();
    const auto op_a        = CUBLAS_OP_T;
    const auto op_b        = CUBLAS_OP_N;
    const auto dataType    = CUDA_R_8I;
    const auto resultType  = mode == 0 ? CUDA_R_8I : CUDA_R_32I;
    const auto computeType = CUBLAS_COMPUTE_32I;
    const auto scaleType   = mode == 0 ?
        CUDA_R_32F : CUDA_R_32I;
    const int   batch_count = 1;
    const void* beta;

    int findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, getCublasDataType(dataType));

    cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(dataType));

    cublasLtMatmulDesc_t   operationDesc = NULL;
    cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;

    // --------------------------------------
    // Create descriptors for the original matrices
    check_cuda_error(cublasLtMatrixLayoutCreate(&Adesc, dataType, k, m, lda));
    check_cuda_error(cublasLtMatrixLayoutCreate(&Bdesc, dataType, k, n, ldb));
    check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, resultType, m, n, ldc));

    check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType));

    auto pointer_mode = CUBLASLT_POINTER_MODE_HOST;
    if (mode == 0) {
        pointer_mode =
            per_column_scaling ? CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST : CUBLASLT_POINTER_MODE_DEVICE;
    }
    check_cuda_error(
        cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_a, sizeof(cublasOperation_t)));
    check_cuda_error(
        cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_b, sizeof(cublasOperation_t)));
    check_cuda_error(
        cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSC, &op_b, sizeof(cublasOperation_t)));
    check_cuda_error(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));

    const int32_t int_one    = 1;
    const int32_t int_zero   = 0;
    const float   float_zero = 0;
    if (mode == 0) {
        beta = per_column_scaling ? &float_zero : NULL;
    }
    else {
        alpha = &int_one;
        beta  = &int_zero;
    }

    cublasLtMatmulAlgo_t algo;
    void*                workSpace     = cublas_workspace_;
    int                  workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;

    sync_check_cuda_error();
    auto ret = cublasLtMatmulWrapper(cublaslt_handle_,
                                     operationDesc,
                                     alpha,
                                     A,
                                     Adesc,
                                     B,
                                     Bdesc,
                                     beta,
                                     C,
                                     Cdesc,
                                     C,
                                     Cdesc,
                                     NULL,
                                     workSpace,
                                     workspaceSize,
                                     stream_);
    check_cuda_error(ret);
    sync_check_cuda_error();

    cublasLtMatmulDescDestroy(operationDesc);
    cublasLtMatrixLayoutDestroy(Adesc);
    cublasLtMatrixLayoutDestroy(Bdesc);
    cublasLtMatrixLayoutDestroy(Cdesc);
    sync_check_cuda_error();
    mu_->unlock();
#endif
}

void cublasMMWrapper::Int8Gemm(const int     m,
                               const int     n,
                               const int     k,
                               const int8_t* A,
                               const int     lda,
                               const int8_t* B,
                               const int     ldb,
                               int8_t*       C,
                               const int     ldc,
                               const float*  alpha,
                               const bool    per_column_scaling)
{
    return _Int8Gemm(m, n, k, A, lda, B, ldb, C, ldc, alpha, 0, per_column_scaling);
}

void cublasMMWrapper::Int8Gemm(const int     m,
                               const int     n,
                               const int     k,
                               const int8_t* A,
                               const int     lda,
                               const int8_t* B,
                               const int     ldb,
                               int32_t*      C,
                               const int     ldc)
{
    return _Int8Gemm(m, n, k, A, lda, B, ldb, C, ldc, (float*)nullptr, 1, false);
}

}  // namespace turbomind
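
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the library; names such as
// cublas_handle, algo_map, cublas_mutex, allocator, d_A/d_B/d_C are assumed
// to be created elsewhere by the caller). It only uses entry points defined
// in this file: the constructor, setFP16GemmConfig() and the convenience
// Gemm() overload, which computes C = A * B with alpha = 1.0f, beta = 0.0f.
//
//   turbomind::cublasMMWrapper wrapper(
//       cublas_handle, cublaslt_handle, stream, &algo_map, &cublas_mutex, allocator);
//   wrapper.setFP16GemmConfig();
//   wrapper.Gemm(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, d_A, lda, d_B, ldb, d_C, ldc);
//
// The FP16 compute-type switch used inside Gemm() is controlled at runtime by
// the environment variables LMDEPLOY_HGEMM_SWITCH and LMDEPLOY_HGEMM_SWITCH_N
// read in the constructor above.
// ---------------------------------------------------------------------------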