Commit d8ae62c7 authored by zhangshao
Browse files

Update layernorm_kernels_opt.cu

parent bf278a88
......@@ -338,9 +338,9 @@ __global__ void fused_add_rms_kernel_opt(scalar_t* input,scalar_t* residual,scal
T_ACC trstd;
int idx = i * tcol + j;
idx*=Vec;
if (j < tcol) {
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
*(LoadT*)residual_vec = *(LoadT*)(residual+idx);
if (j < tcol) {
#pragma unroll
for (int ii = 0; ii < Vec; ii++) {
residual_vec[ii]+=intput_vec[ii];
......@@ -377,8 +377,8 @@ __global__ void fused_rms_kernel_opt(scalar_t* input,scalar_t* output,scalar_t*
T_ACC trstd;
int idx = i * tcol + j;
idx*=Vec;
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
if (j < tcol) {
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
#pragma unroll
for (int ii = 0; ii < Vec; ii++) {
val += static_cast<T_ACC>(intput_vec[ii])*static_cast<T_ACC>(intput_vec[ii]);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment