Commit 785f450d authored by zhangshao
Browse files

修复rmsnorm bug,增加USE_VLLM_OLD_OP标志使用原版rmsnorm

parent 1c5e7720
...@@ -17,7 +17,16 @@ ...@@ -17,7 +17,16 @@
using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162; using __nv_bfloat162 = __hip_bfloat162;
#endif #endif
// Interprets an environment variable as an on/off switch.
// Yields false when `env_var` is absent from the environment or is set to
// exactly "0"; every other value (the empty string included) yields true.
static inline bool get_env_(const char *env_var) {
  const char *raw = std::getenv(env_var);
  if (raw == nullptr) {
    return false;
  }
  return strcmp(raw, "0") != 0;
}

// Snapshot of USE_VLLM_OLD_OP, taken once during static initialization, so
// later changes to the environment are not observed. The dispatch conditions
// in rms_norm and fused_add_rms_norm test `!use_old` before entering their
// size/alignment-gated branch.
static const bool use_old = get_env_("USE_VLLM_OLD_OP");
namespace vllm { namespace vllm {
// TODO(woosuk): Further optimize this kernel. // TODO(woosuk): Further optimize this kernel.
...@@ -332,7 +341,6 @@ __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,sca ...@@ -332,7 +341,6 @@ __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,sca
int i=blockIdx.x; int i=blockIdx.x;
int j=threadIdx.x; int j=threadIdx.x;
int tcol=cols/Vec; int tcol=cols/Vec;
if(j>=tcol)return;
using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>; using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;
scalar_t intput_vec[Vec]; scalar_t intput_vec[Vec];
scalar_t residual_vec[Vec]; scalar_t residual_vec[Vec];
...@@ -341,22 +349,26 @@ __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,sca ...@@ -341,22 +349,26 @@ __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,sca
idx*=Vec; idx*=Vec;
*(LoadT*)intput_vec = *(LoadT*)(input+idx); *(LoadT*)intput_vec = *(LoadT*)(input+idx);
*(LoadT*)residual_vec = *(LoadT*)(residual+idx); *(LoadT*)residual_vec = *(LoadT*)(residual+idx);
#pragma unroll if (j < tcol) {
for (int ii = 0; ii < Vec; ii++) { #pragma unroll
residual_vec[ii]+=intput_vec[ii]; for (int ii = 0; ii < Vec; ii++) {
val += static_cast<T_ACC>(residual_vec[ii])*static_cast<T_ACC>(residual_vec[ii]); residual_vec[ii]+=intput_vec[ii];
val += static_cast<T_ACC>(residual_vec[ii])*static_cast<T_ACC>(residual_vec[ii]);
}
} }
val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared); val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared);
if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps); if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps);
__syncthreads(); __syncthreads();
trstd=s_rstd; trstd=s_rstd;
#pragma unroll if (j < tcol) {
for(int ii=0;ii<Vec;ii++){ #pragma unroll
int jj=j*Vec+ii; for(int ii=0;ii<Vec;ii++){
intput_vec[ii] = static_cast<T_ACC>(residual_vec[ii]) *trstd* static_cast<T_ACC>(gamma[jj]); int jj=j*Vec+ii;
intput_vec[ii] = static_cast<T_ACC>(residual_vec[ii]) *trstd* static_cast<T_ACC>(gamma[jj]);
}
*(LoadT*)(residual+idx)=*(LoadT*)residual_vec;
*(LoadT*)(input+idx)=*(LoadT*)intput_vec;
} }
*(LoadT*)(residual+idx)=*(LoadT*)residual_vec;
*(LoadT*)(input+idx)=*(LoadT*)intput_vec;
} }
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512> template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
...@@ -369,27 +381,30 @@ __global__ void fused_rms_kernel_eval(scalar_t* input,scalar_t* output,scalar_t* ...@@ -369,27 +381,30 @@ __global__ void fused_rms_kernel_eval(scalar_t* input,scalar_t* output,scalar_t*
int i=blockIdx.x; int i=blockIdx.x;
int j=threadIdx.x; int j=threadIdx.x;
int tcol=cols/Vec; int tcol=cols/Vec;
if(j>=tcol)return;
using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>; using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;
scalar_t intput_vec[Vec]; scalar_t intput_vec[Vec];
T_ACC trstd; T_ACC trstd;
int idx = i * tcol + j; int idx = i * tcol + j;
idx*=Vec; idx*=Vec;
*(LoadT*)intput_vec = *(LoadT*)(input+idx); *(LoadT*)intput_vec = *(LoadT*)(input+idx);
#pragma unroll if (j < tcol) {
for (int ii = 0; ii < Vec; ii++) { #pragma unroll
val += static_cast<T_ACC>(intput_vec[ii])*static_cast<T_ACC>(intput_vec[ii]); for (int ii = 0; ii < Vec; ii++) {
val += static_cast<T_ACC>(intput_vec[ii])*static_cast<T_ACC>(intput_vec[ii]);
}
} }
val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared); val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared);
if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps); if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps);
__syncthreads(); __syncthreads();
trstd=s_rstd; trstd=s_rstd;
#pragma unroll if (j < tcol) {
for(int ii=0;ii<Vec;ii++){ #pragma unroll
int jj=j*Vec+ii; for(int ii=0;ii<Vec;ii++){
intput_vec[ii] = static_cast<T_ACC>(intput_vec[ii]) *trstd* static_cast<T_ACC>(gamma[jj]); int jj=j*Vec+ii;
intput_vec[ii] = static_cast<T_ACC>(intput_vec[ii]) *trstd* static_cast<T_ACC>(gamma[jj]);
}
*(LoadT*)(output+idx)=*(LoadT*)intput_vec;
} }
*(LoadT*)(output+idx)=*(LoadT*)intput_vec;
} }
void rms_norm(torch::Tensor& out, // [..., hidden_size] void rms_norm(torch::Tensor& out, // [..., hidden_size]
...@@ -403,7 +418,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] ...@@ -403,7 +418,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr()); auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr()); auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =inp_ptr % 16 == 0 && wt_ptr % 16 == 0; bool ptrs_are_aligned =inp_ptr % 16 == 0 && wt_ptr % 16 == 0;
if(hidden_size%16==0&&hidden_size<=16384&&ptrs_are_aligned){ if(!use_old&&hidden_size%16==0&&hidden_size<=16384&&ptrs_are_aligned){
AT_DISPATCH_FLOATING_TYPES_AND2( AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::Half,
at::ScalarType::BFloat16, at::ScalarType::BFloat16,
...@@ -473,7 +488,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] ...@@ -473,7 +488,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr()); auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr()); auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; bool ptrs_are_aligned =inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if(hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192&&ptrs_are_aligned){ if(!use_old&&hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192&&ptrs_are_aligned){
AT_DISPATCH_FLOATING_TYPES_AND2( AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::Half,
at::ScalarType::BFloat16, at::ScalarType::BFloat16,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment