/**
 * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "norm.cuh"

template <int block_size>
static __global__ void __launch_bounds__(1024) norm_f32(const float * x, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    float2 mean_var = make_float2(0.f, 0.f);

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row*ncols + col];
        mean_var.x += xi;
        mean_var.y += xi * xi;
    }

    // sum up partial sums
    mean_var = warp_reduce_sum(mean_var);
    if (block_size > WARP_SIZE) {
        __shared__ float2 s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = mean_var;
        }
        __syncthreads();
        mean_var = s_sum[lane_id];
        mean_var = warp_reduce_sum(mean_var);
    }

    const float mean = mean_var.x / ncols;
    const float var = mean_var.y / ncols - mean * mean;
    const float inv_std = rsqrtf(var + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
    }
}

template <int block_size>
static __global__ void __launch_bounds__(1024) group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
    // blockIdx.x: num_groups idx
    // threadIdx.x: block_size idx
    int start = blockIdx.x * group_size;
    int end = start + group_size;

    start += threadIdx.x;

    if (end >= ne_elements) {
        end = ne_elements;
    }

    float tmp = 0.0f; // partial sum for thread in warp

    for (int j = start; j < end; j += block_size) {
        tmp += x[j];
    }

    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    float mean = tmp / group_size;
    tmp = 0.0f;

    for (int j = start; j < end; j += block_size) {
        float xi = x[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }

    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    float variance = tmp / group_size;
    float scale = rsqrtf(variance + eps);
    for (int j = start; j < end; j += block_size) {
        dst[j] *= scale;
    }
}
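// ---------------------------------------------------------------------------
// Editor's sketch (not part of upstream norm.cu): the kernels in this file
// rely on warp_reduce_sum() from common.cuh (pulled in through norm.cuh).
// It is assumed to leave every lane of the warp holding the sum of the input
// value across the whole warp. The hypothetical helper below only illustrates
// that assumed behaviour; it is not the upstream definition and is not called
// by any code in this file.
// ---------------------------------------------------------------------------
static __device__ __forceinline__ float warp_reduce_sum_ref(float x) {
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        // butterfly pattern: after log2(WARP_SIZE) steps every lane holds the total
        x += __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE);
    }
    return x;
}
// The float2 overload used by norm_f32 is assumed to reduce .x and .y
// componentwise in the same way.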
template <int block_size>
static __global__ void __launch_bounds__(1024) rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row*ncols + col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float mean = tmp / ncols;
    const float scale = rsqrtf(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row*ncols + col] = scale * x[row*ncols + col];
    }
}

// 4-wide float vector type used by the vectorized RMS norm kernel below
using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;

// block-wide sum reduction over 64-lane wavefronts: first reduce within each
// wavefront, then combine the per-wavefront partial sums through shared memory
// in wavefront 0
template <typename T, int NUM_WARPS>
__inline__ __device__ T BlockReduceSumVEC(T& val, T* shared) {
#pragma unroll
    for (int offset = 32; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xffffffff, val, offset, 64); //64
    }
    if constexpr (1 < NUM_WARPS) {
        const int tid = threadIdx.x;
        const int lid = tid % 64;
        const int wid = tid / 64;
        if (lid == 0) {
            shared[wid] = val;
        }
        __syncthreads();
        if (wid == 0 && lid < NUM_WARPS) {
#pragma unroll
            for (int offset = NUM_WARPS/2; offset > 0; offset >>= 1) {
                shared[lid] += __shfl_xor_sync(0xffffffff, shared[lid], offset, 64); //64
            }
            val = shared[lid];
        }
    }
    return val;
}

template <int block_size, int VEC = 4>
static __global__ void __launch_bounds__(1024) rms_norm_f32_opt1(const float * x, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    constexpr int NUM_WARPS = block_size / 64;
    __shared__ float lds_sum[NUM_WARPS*4];
    __shared__ float sum_val;

    float tmp = 0.0f; // partial sum for thread in warp

    floatx4 xi_vec;
    for (int col = tid*VEC; col < ncols; col += block_size*VEC) {
        xi_vec = *(floatx4*)(x + row*ncols + col);
#pragma unroll
        for (int i = 0; i < VEC; ++i) {
            tmp += xi_vec[i]*xi_vec[i];
        }
    }

    tmp = BlockReduceSumVEC<float, NUM_WARPS>(tmp, lds_sum);
    // tmp = __shfl_sync(0xffffffff, tmp, 0); //lds or shfl
    if (tid == 0) sum_val = rsqrtf(tmp / ncols + eps);
    __syncthreads();

    float scale = sum_val; // keep the scale in a register instead of re-reading shared memory
    for (int col = tid*VEC; col < ncols; col += block_size*VEC) {
        xi_vec = *(floatx4*)(x + row*ncols + col);
#pragma unroll
        for (int i = 0; i < VEC; ++i) {
            xi_vec[i] = xi_vec[i] * scale;
        }
        *(floatx4*)(dst + row*ncols + col) = xi_vec;
    }
}

static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    }
}

static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
    if (group_size < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
    }
}
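// ---------------------------------------------------------------------------
// Editor's note (observation, not upstream documentation): rms_norm_f32_opt1
// above loads each row as 4-wide float vectors and reduces with 64-lane
// shuffles (__shfl_xor_sync(..., 64)), so it appears to target GPUs with a
// 64-wide wavefront (e.g. AMD GPUs via the HIP build). The vectorized loads
// additionally assume that ncols is a multiple of 4 (already guaranteed by
// the ncols % WARP_SIZE == 0 assert below) and that each row of x and dst
// starts on a 16-byte boundary. The dispatcher below only takes the
// vectorized path for rows with ncols >= 1024.
// ---------------------------------------------------------------------------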
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        rms_norm_f32_opt1<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    }
}

void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}

void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    int num_groups = dst->op_params[0];

    float eps;
    memcpy(&eps, dst->op_params + 1, sizeof(float));

    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
}

void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}
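// ---------------------------------------------------------------------------
// Editor's sketch (not part of upstream norm.cu): a plain host-side reference
// of the row-wise RMS norm computed by rms_norm_f32 / rms_norm_f32_opt1,
// handy for ad-hoc validation of kernel output. Hypothetical helper, not
// called by any code above; sqrtf and int64_t are assumed to be declared
// through the headers pulled in by norm.cuh.
// ---------------------------------------------------------------------------
static void rms_norm_f32_ref_host(const float * x, float * dst, const int ncols, const int nrows, const float eps) {
    for (int row = 0; row < nrows; ++row) {
        const float * src_row = x   + (int64_t) row*ncols;
        float       * dst_row = dst + (int64_t) row*ncols;

        // mean of squares over the row
        float sum = 0.0f;
        for (int col = 0; col < ncols; ++col) {
            sum += src_row[col]*src_row[col];
        }

        // same scale as the kernels: 1/sqrt(mean(x^2) + eps)
        const float scale = 1.0f/sqrtf(sum/ncols + eps);
        for (int col = 0; col < ncols; ++col) {
            dst_row[col] = scale*src_row[col];
        }
    }
}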