Commit a0b44ca8 authored by zhouxiang

Change the default GEMM computation to mixed precision

parent fe851fbc
@@ -64,7 +64,7 @@ def get_version_add(sha: Optional[str] = None) -> str:
     lines=[]
     with open(add_version_path, 'r',encoding='utf-8') as file:
         lines = file.readlines()
-    lines[2] = "__dcu_version__ = '0.1.0+{}'\n".format(version)
+    lines[2] = "__dcu_version__ = '0.2.6+{}'\n".format(version)
     with open(add_version_path, encoding="utf-8",mode="w") as file:
         file.writelines(lines)
         file.close()
...
@@ -272,9 +272,19 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
     std::unique_ptr<cudaDeviceProp> cuda_device_prop_ptr(new cudaDeviceProp);
     ft::check_cuda_error(cudaGetDeviceProperties(cuda_device_prop_ptr.get(), device_id));
+    int         hgemm_switch      = 0;
+    const char* env_var_value_str = std::getenv("LMDEPLOY_HGEMM_SWITCH");
+    if (env_var_value_str != nullptr) {
+        hgemm_switch = std::stoi(env_var_value_str);
+    }
     if (std::is_same<T, half>::value) {
-        // cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F);
-        cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F);
+        if (hgemm_switch == 2) {
+            cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F);
+        }
+        else {
+            cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F);
+        }
     }
     else if (std::is_same<T, float>::value) {
         cublas_wrapper->setFP32GemmConfig();
...
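For context: the last argument of setGemmConfig is the cuBLAS compute type, so the new default keeps FP16 inputs and outputs but accumulates in FP32 (mixed precision), while LMDEPLOY_HGEMM_SWITCH=2 restores pure FP16 accumulation. The sketch below is hypothetical standalone code, not from this repo; it assumes the legacy cublasGemmEx overload that takes cudaDataType_t as its compute type (the same form this wrapper calls) and omits error checks. It shows how the two configurations differ at the call site; note that alpha and beta must be typed to match the compute type, which is what is_fp16_computeType tracks further down.

    #include <cublas_v2.h>
    #include <cuda_fp16.h>

    // Hypothetical sketch (not part of the commit): the same FP16 GEMM
    // issued with FP32 vs. FP16 accumulation.
    void gemm_fp16(cublasHandle_t handle, int m, int n, int k,
                   const half* A, const half* B, half* C, bool fp16_accumulate)
    {
        if (fp16_accumulate) {  // LMDEPLOY_HGEMM_SWITCH=2: pure FP16 ("fp16r")
            half h_alpha = __float2half(1.0f), h_beta = __float2half(0.0f);
            cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                         &h_alpha, A, CUDA_R_16F, m, B, CUDA_R_16F, k,
                         &h_beta, C, CUDA_R_16F, m,
                         CUDA_R_16F, CUBLAS_GEMM_DEFAULT);
        }
        else {  // new default: FP16 in/out, FP32 accumulate ("fp16r-fp32r")
            float f_alpha = 1.0f, f_beta = 0.0f;
            cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                         &f_alpha, A, CUDA_R_16F, m, B, CUDA_R_16F, k,
                         &f_beta, C, CUDA_R_16F, m,
                         CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
        }
    }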
@@ -40,6 +40,18 @@ cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle,
     if (allocator_ != nullptr) {
         cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false);
     }
+    // hgemm-switch 0:fp32r, 1:fp16r-fp32r, 2:fp16r ---- xzhou 20240427
+    m_ihgemm_switch = 0;
+    const char* env_var_value_str = std::getenv("LMDEPLOY_HGEMM_SWITCH");
+    if (env_var_value_str != nullptr) {
+        m_ihgemm_switch = std::stoi(env_var_value_str);
+    }
+    m_ihgemm_switch_n = 16;
+    const char* env_n_value_str = std::getenv("LMDEPLOY_HGEMM_SWITCH_N");
+    if (env_n_value_str != nullptr) {
+        m_ihgemm_switch_n = std::stoi(env_n_value_str);
+    }
+    if (m_ihgemm_switch == 0) printf("hgemm_switch=%d, hgemm_switch_n_limit=%d\n", m_ihgemm_switch, m_ihgemm_switch_n);
 }
 #ifdef SPARSITY_ENABLED
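One caveat about the parsing above: std::stoi throws std::invalid_argument or std::out_of_range when the variable is set to a non-numeric value, which would terminate the process at startup. A hardened variant could look like the following hypothetical helper (the name env_to_int is illustrative, not part of the commit):

    #include <cstdlib>
    #include <exception>
    #include <string>

    // Hypothetical helper, not part of the commit: parse an integer
    // environment variable, falling back to a default when the variable
    // is unset or malformed instead of letting std::stoi throw.
    static int env_to_int(const char* name, int fallback)
    {
        const char* s = std::getenv(name);
        if (s == nullptr) {
            return fallback;
        }
        try {
            return std::stoi(s);  // throws invalid_argument / out_of_range
        }
        catch (const std::exception&) {
            return fallback;
        }
    }

    // Usage mirroring the constructor above:
    //   m_ihgemm_switch   = env_to_int("LMDEPLOY_HGEMM_SWITCH", 0);
    //   m_ihgemm_switch_n = env_to_int("LMDEPLOY_HGEMM_SWITCH_N", 16);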
@@ -113,6 +125,10 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     mu_->lock();
+    // hgemm-switch ---- xzhou 20240427
+    if (m_ihgemm_switch == 1 && (m == 5120 || m == 4096 || m == 12288 || m == 11008) && n <= m_ihgemm_switch_n && Atype == CUDA_R_16F) {
+        computeType = CUDA_R_16F;
+    }
     check_cuda_error(cublasGemmEx(cublas_handle_,
                                   transa,
                                   transb,
@@ -172,7 +188,12 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
     mu_->lock();
     // TODO: default cublas libs
-    int  is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0;
+    cudaDataType_t computeType = computeType_;
+    // hgemm-switch ---- xzhou 20240427
+    if (m_ihgemm_switch == 1 && (m == 5120 || m == 4096 || m == 12288 || m == 11008) && n <= m_ihgemm_switch_n && Atype_ == CUDA_R_16F) {
+        computeType = CUDA_R_16F;
+    }
+    int  is_fp16_computeType = computeType == CUDA_R_16F ? 1 : 0;
     bool using_cublasLt = (Atype_ == CUDA_R_16F) ? true : false;
     int  batch_count = 1;
     // fp32 use cublas as default
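The same mode-1 condition appears in both Gemm overloads above: accumulation is dropped to FP16 only when m matches one of four hard-coded weight dimensions, n is at or below the LMDEPLOY_HGEMM_SWITCH_N limit (default 16, i.e. the skinny GEMMs typical of decoding a few tokens), and the inputs are FP16. Factored out as a hypothetical predicate (illustrative only, not part of the commit):

    #include <library_types.h>  // cudaDataType_t

    // Hypothetical refactoring, not in the commit: the condition guarding
    // the FP16-accumulation override, shared by both Gemm overloads.
    static bool use_fp16_accumulation(int hgemm_switch, int n_limit,
                                      int m, int n, cudaDataType_t a_type)
    {
        // Mode 1 only touches a whitelist of weight dimensions and
        // small-n (few-token) GEMMs, and only for FP16 inputs.
        const bool m_whitelisted = (m == 5120 || m == 4096 || m == 12288 || m == 11008);
        return hgemm_switch == 1 && m_whitelisted && n <= n_limit && a_type == CUDA_R_16F;
    }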
@@ -323,7 +344,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
                                       C,
                                       Ctype_,
                                       ldc,
-                                      computeType_,
+                                      computeType,
                                       static_cast<cublasGemmAlgo_t>(cublasAlgo)));
         sync_check_cuda_error();
         // }
@@ -343,7 +364,7 @@ void cublasMMWrapper::setFP16GemmConfig()
     Atype_       = CUDA_R_16F;
     Btype_       = CUDA_R_16F;
     Ctype_       = CUDA_R_16F;
-    computeType_ = CUDA_R_16F;
+    computeType_ = CUDA_R_32F;
 }
 #ifdef ENABLE_BF16
...
@@ -48,6 +48,9 @@ protected:
     cublasAlgoMap* cublas_algo_map_;
     std::mutex*    mu_;
+    int m_ihgemm_switch;
+    int m_ihgemm_switch_n;
     IAllocator* allocator_        = nullptr;
     void*       cublas_workspace_ = nullptr;
...