Commit 1dc4b857 authored by xuxzh1

opt rms_norm

parent 23a7a73f
@@ -157,6 +157,81 @@ static __global__ void __launch_bounds__(1024) rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
}
}
// 16-byte vector of four floats (clang/GCC vector extension), used for vectorized loads/stores.
using floatx4 = float __attribute__((__vector_size__(4 * sizeof(float))));
template <typename T, int VEC, int NUM_WARPS>
__inline__ __device__ T BlockReduceSumVEC(T& val, T* shared) {
    // Butterfly reduction across one 64-lane wavefront (AMD/DCU wave64).
    // VEC is unused here; it is kept so the call site can pass its vector width.
#pragma unroll
    for (int offset = 32; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xffffffff, val, offset, 64); // shuffle width = wavefront size (64)
    }
    if constexpr (NUM_WARPS > 1) {
        const int tid = threadIdx.x;
        const int lid = tid % 64; // lane id within the wavefront
        const int wid = tid / 64; // wavefront id within the block
        if (lid == 0) {
            shared[wid] = val; // one partial sum per wavefront
        }
        __syncthreads();
        if (wid == 0) {
            // Let all 64 lanes of the first wavefront enter the shuffle so the
            // butterfly does not run in a divergent branch; lanes beyond
            // NUM_WARPS just contribute 0.
            T v = (lid < NUM_WARPS) ? shared[lid] : T(0);
#pragma unroll
            for (int offset = NUM_WARPS/2; offset > 0; offset >>= 1) {
                v += __shfl_xor_sync(0xffffffff, v, offset, 64);
            }
            val = v;
        }
    }
    return val;
}
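// Reviewer sketch (not part of this commit): a minimal kernel, under the same
// wave64 / 1024-thread-block assumptions, that exercises BlockReduceSumVEC in
// isolation. The kernel name and buffers are hypothetical.
template <int block_size>
static __global__ void block_reduce_sum_check(const float * in, float * out) {
    constexpr int NUM_WARPS = block_size / 64;
    __shared__ float lds[NUM_WARPS];
    float v = in[blockIdx.x*block_size + threadIdx.x];
    v = BlockReduceSumVEC<float, 4, NUM_WARPS>(v, lds);
    if (threadIdx.x == 0) {
        out[blockIdx.x] = v; // full block sum after the two-level reduction
    }
}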
template <int block_size, int VEC = 4>
static __global__ void __launch_bounds__(1024) rms_norm_f32_opt1(const float * x, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;
    constexpr int NUM_WARPS = block_size / 64; // wavefronts per block (wave64)
    __shared__ float lds_sum[NUM_WARPS*4];     // reduction scratch (only NUM_WARPS entries are used)
    __shared__ float sum_val;                  // broadcast slot for the row's scale factor
    float tmp = 0.0f; // per-thread partial sum of squares
    floatx4 xi_vec;
    // First pass: vectorized 16-byte loads; assumes ncols is a multiple of VEC,
    // which the caller's GGML_ASSERT(ncols % WARP_SIZE == 0) guarantees.
    for (int col = tid*VEC; col < ncols; col += block_size*VEC) {
        xi_vec = *(const floatx4*)(x + row*ncols + col);
#pragma unroll
        for (int i = 0; i < VEC; ++i) {
            tmp += xi_vec[i]*xi_vec[i];
        }
    }
    tmp = BlockReduceSumVEC<float, VEC, NUM_WARPS>(tmp, lds_sum);
    // tmp = __shfl_sync(0xffffffff, tmp, 0); // alternative broadcast via shuffle (single-wavefront blocks only)
    if (tid == 0) {
        sum_val = rsqrtf(tmp / ncols + eps); // 1 / RMS of the row
    }
    __syncthreads();
    const float scale = sum_val;
    // Second pass: reuse the same vector register for the load, scale, and store.
    for (int col = tid*VEC; col < ncols; col += block_size*VEC) {
        xi_vec = *(const floatx4*)(x + row*ncols + col);
#pragma unroll
        for (int i = 0; i < VEC; ++i) {
            xi_vec[i] = xi_vec[i] * scale;
        }
        *(floatx4*)(dst + row*ncols + col) = xi_vec;
    }
}
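// Reviewer sketch (not part of this commit): a scalar host-side reference that
// mirrors the kernel's math (scale each row by rsqrt(mean(x^2) + eps)), useful
// for spot-checking rms_norm_f32_opt1. Requires <math.h>; the name is hypothetical.
static void rms_norm_f32_ref(const float * x, float * dst, const int ncols, const int nrows, const float eps) {
    for (int r = 0; r < nrows; ++r) {
        const float * row = x + (size_t)r*ncols;
        float sum = 0.0f;
        for (int c = 0; c < ncols; ++c) {
            sum += row[c]*row[c];
        }
        const float scale = 1.0f/sqrtf(sum/ncols + eps);
        for (int c = 0; c < ncols; ++c) {
            dst[(size_t)r*ncols + c] = row[c]*scale;
        }
    }
}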
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
@@ -185,7 +260,7 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32_opt1<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}
}
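// Reviewer sketch (not part of this commit): a minimal host-side check of the
// new path against the scalar reference sketched above. Call with ncols >= 1024
// (and ncols % WARP_SIZE == 0) so the dispatch picks rms_norm_f32_opt1. Assumes
// <vector>, <cassert>, and <cmath> are available; error handling is omitted and
// the function name is hypothetical.
static void check_rms_norm_opt1(const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    const size_t n = (size_t)ncols*nrows;
    std::vector<float> h_x(n), h_ref(n), h_out(n);
    for (size_t i = 0; i < n; ++i) {
        h_x[i] = (float)(i % 7) - 3.0f; // deterministic test pattern
    }

    float * d_x = nullptr;
    float * d_dst = nullptr;
    cudaMalloc(&d_x, n*sizeof(float));
    cudaMalloc(&d_dst, n*sizeof(float));
    cudaMemcpyAsync(d_x, h_x.data(), n*sizeof(float), cudaMemcpyHostToDevice, stream);

    rms_norm_f32_cuda(d_x, d_dst, ncols, nrows, eps, stream);

    cudaMemcpyAsync(h_out.data(), d_dst, n*sizeof(float), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);

    rms_norm_f32_ref(h_x.data(), h_ref.data(), ncols, nrows, eps);
    for (size_t i = 0; i < n; ++i) {
        // tolerate small drift from the different summation order on the GPU
        assert(fabsf(h_out[i] - h_ref[i]) < 1e-4f);
    }
    cudaFree(d_x);
    cudaFree(d_dst);
}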