Commit 785f450d authored by zhangshao's avatar zhangshao
Browse files

修复rmsnorm bug,增加USE_VLLM_OLD_OP标志使用原版rmsnorm

parent 1c5e7720
......@@ -17,7 +17,16 @@
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif
static inline bool get_env_(const char *env_var) {
if (char *value = std::getenv(env_var)) {
if (strcmp(value, "0") == 0) {
return false;
}
return true;
}
return false;
}
static const bool use_old= get_env_("USE_VLLM_OLD_OP");
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
......@@ -332,7 +341,6 @@ __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,sca
int i=blockIdx.x;
int j=threadIdx.x;
int tcol=cols/Vec;
if(j>=tcol)return;
using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;
scalar_t intput_vec[Vec];
scalar_t residual_vec[Vec];
......@@ -341,15 +349,18 @@ __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,sca
idx*=Vec;
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
*(LoadT*)residual_vec = *(LoadT*)(residual+idx);
if (j < tcol) {
#pragma unroll
for (int ii = 0; ii < Vec; ii++) {
residual_vec[ii]+=intput_vec[ii];
val += static_cast<T_ACC>(residual_vec[ii])*static_cast<T_ACC>(residual_vec[ii]);
}
}
val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared);
if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps);
__syncthreads();
trstd=s_rstd;
if (j < tcol) {
#pragma unroll
for(int ii=0;ii<Vec;ii++){
int jj=j*Vec+ii;
......@@ -357,6 +368,7 @@ __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,sca
}
*(LoadT*)(residual+idx)=*(LoadT*)residual_vec;
*(LoadT*)(input+idx)=*(LoadT*)intput_vec;
}
}
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
......@@ -369,27 +381,30 @@ __global__ void fused_rms_kernel_eval(scalar_t* input,scalar_t* output,scalar_t*
int i=blockIdx.x;
int j=threadIdx.x;
int tcol=cols/Vec;
if(j>=tcol)return;
using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;
scalar_t intput_vec[Vec];
T_ACC trstd;
int idx = i * tcol + j;
idx*=Vec;
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
if (j < tcol) {
#pragma unroll
for (int ii = 0; ii < Vec; ii++) {
val += static_cast<T_ACC>(intput_vec[ii])*static_cast<T_ACC>(intput_vec[ii]);
}
}
val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared);
if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps);
__syncthreads();
trstd=s_rstd;
if (j < tcol) {
#pragma unroll
for(int ii=0;ii<Vec;ii++){
int jj=j*Vec+ii;
intput_vec[ii] = static_cast<T_ACC>(intput_vec[ii]) *trstd* static_cast<T_ACC>(gamma[jj]);
}
*(LoadT*)(output+idx)=*(LoadT*)intput_vec;
}
}
void rms_norm(torch::Tensor& out, // [..., hidden_size]
......@@ -403,7 +418,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =inp_ptr % 16 == 0 && wt_ptr % 16 == 0;
if(hidden_size%16==0&&hidden_size<=16384&&ptrs_are_aligned){
if(!use_old&&hidden_size%16==0&&hidden_size<=16384&&ptrs_are_aligned){
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
......@@ -473,7 +488,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if(hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192&&ptrs_are_aligned){
if(!use_old&&hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192&&ptrs_are_aligned){
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment