Commit 669dc816 authored by zhouxiang's avatar zhouxiang
Browse files

支持混精和半精切换能力 (Support switching between mixed-precision and half-precision GEMM compute)

parent 5f83e392
...@@ -272,9 +272,19 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh ...@@ -272,9 +272,19 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
std::unique_ptr<cudaDeviceProp> cuda_device_prop_ptr(new cudaDeviceProp); std::unique_ptr<cudaDeviceProp> cuda_device_prop_ptr(new cudaDeviceProp);
ft::check_cuda_error(cudaGetDeviceProperties(cuda_device_prop_ptr.get(), device_id)); ft::check_cuda_error(cudaGetDeviceProperties(cuda_device_prop_ptr.get(), device_id));
int hgemm_switch = 0;
const char* env_var_value_str = std::getenv("LMDEPLOY_HGEMM_SWITCH");
if (env_var_value_str != nullptr) {
hgemm_switch = std::stoi(env_var_value_str);
}
if (std::is_same<T, half>::value) { if (std::is_same<T, half>::value) {
// cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); if(hgemm_switch == 2){
cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F); cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F);
}
else{
cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F);
}
} }
else if (std::is_same<T, float>::value) { else if (std::is_same<T, float>::value) {
cublas_wrapper->setFP32GemmConfig(); cublas_wrapper->setFP32GemmConfig();
......
...@@ -40,6 +40,18 @@ cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle, ...@@ -40,6 +40,18 @@ cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle,
if (allocator_ != nullptr) { if (allocator_ != nullptr) {
cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false); cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false);
} }
// hgemm-switch 0:fp32r,1:fp16r-fp32r,2:fp16r ----xzhou 20240427
m_ihgemm_switch = 0;
const char* env_var_value_str = std::getenv("LMDEPLOY_HGEMM_SWITCH");
if (env_var_value_str != nullptr) {
m_ihgemm_switch = std::stoi(env_var_value_str);
}
m_ihgemm_switch_n = 16;
const char* env_n_value_str = std::getenv("LMDEPLOY_HGEMM_SWITCH_N");
if (env_n_value_str != nullptr) {
m_ihgemm_switch_n = std::stoi(env_n_value_str);
}
if(m_ihgemm_switch != 0) printf("hgemm_switch=%d, hgemm_switch_n_limit=%d\n", m_ihgemm_switch, m_ihgemm_switch_n);
} }
#ifdef SPARSITY_ENABLED #ifdef SPARSITY_ENABLED
...@@ -113,6 +125,10 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, ...@@ -113,6 +125,10 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
{ {
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
mu_->lock(); mu_->lock();
// hgemm-switch ----xzhou 20240427
if(m_ihgemm_switch == 1 && (m == 5120 || m == 4096 || m == 12288 || m == 11008) && n <= m_ihgemm_switch_n && Atype == CUDA_R_16F){
computeType = CUDA_R_16F;
}
check_cuda_error(cublasGemmEx(cublas_handle_, check_cuda_error(cublasGemmEx(cublas_handle_,
transa, transa,
transb, transb,
...@@ -172,7 +188,12 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, ...@@ -172,7 +188,12 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
mu_->lock(); mu_->lock();
// TODO: default cublas libs // TODO: default cublas libs
int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; cudaDataType_t computeType = computeType_;
// hgemm-switch ----xzhou 20240427
if(m_ihgemm_switch == 1 && (m == 5120 || m == 4096 || m == 12288 || m == 11008) && n <= m_ihgemm_switch_n && Atype_ == CUDA_R_16F){
computeType = CUDA_R_16F;
}
int is_fp16_computeType = computeType == CUDA_R_16F ? 1 : 0;
bool using_cublasLt = (Atype_ == CUDA_R_16F) ? true : false; bool using_cublasLt = (Atype_ == CUDA_R_16F) ? true : false;
int batch_count = 1; int batch_count = 1;
// fp32 use cublas as default // fp32 use cublas as default
...@@ -323,7 +344,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, ...@@ -323,7 +344,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
C, C,
Ctype_, Ctype_,
ldc, ldc,
computeType_, computeType,
static_cast<cublasGemmAlgo_t>(cublasAlgo))); static_cast<cublasGemmAlgo_t>(cublasAlgo)));
sync_check_cuda_error(); sync_check_cuda_error();
// } // }
...@@ -343,7 +364,7 @@ void cublasMMWrapper::setFP16GemmConfig() ...@@ -343,7 +364,7 @@ void cublasMMWrapper::setFP16GemmConfig()
Atype_ = CUDA_R_16F; Atype_ = CUDA_R_16F;
Btype_ = CUDA_R_16F; Btype_ = CUDA_R_16F;
Ctype_ = CUDA_R_16F; Ctype_ = CUDA_R_16F;
computeType_ = CUDA_R_16F; computeType_ = CUDA_R_32F;
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
......
...@@ -48,6 +48,9 @@ protected: ...@@ -48,6 +48,9 @@ protected:
cublasAlgoMap* cublas_algo_map_; cublasAlgoMap* cublas_algo_map_;
std::mutex* mu_; std::mutex* mu_;
int m_ihgemm_switch;
int m_ihgemm_switch_n;
IAllocator* allocator_ = nullptr; IAllocator* allocator_ = nullptr;
void* cublas_workspace_ = nullptr; void* cublas_workspace_ = nullptr;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment