// NOTE(review): the header names of the angle-bracket includes were lost from
// this file. The list below is reconstructed from the names the file actually
// uses (torch::Tensor, at::acc_type, at::cuda::getCurrentCUDAStream,
// WARP_SHFL_DOWN, at::native::memory::aligned_vector,
// at::cuda::OptionalCUDAGuard, c10::cuda::compat::rsqrt, std::getenv,
// std::strcmp) -- confirm against the build before merging.
#include <torch/extension.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <cstdlib>
#include <cstring>

#include "dispatch_utils.h"
#include "reduction_utils.cuh"
#ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
#else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>

using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif

// Reads a boolean switch from the environment.
// Returns false when `env_var` is unset or set to exactly "0"; any other
// value (including the empty string) counts as true.
static inline bool get_env_(const char* env_var) {
  const char* value = std::getenv(env_var);
  return value != nullptr && std::strcmp(value, "0") != 0;
}

// Evaluated once during static initialization. When true, the host entry
// points in this file skip the *_kernel_eval fast paths.
static const bool use_old = get_env_("USE_VLLM_OLD_OP");

namespace vllm {

// RMS normalization of one row per thread block.
//
// The row is selected by blockIdx.x and the threads of the block stride over
// it by blockDim.x, so the launch must use one block per row. Squares are
// accumulated in float; the normalized value is cast back to scalar_t before
// it is multiplied by `weight`. `num_tokens` is not read by this kernel.
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
__global__ void rms_norm_kernel(
    scalar_t* __restrict__ out,           // [..., hidden_size]
    const scalar_t* __restrict__ input,   // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // 64-bit row offset: blockIdx.x * hidden_size is otherwise evaluated in
  // 32-bit unsigned arithmetic and wraps for very large inputs.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * hidden_size;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)input[row_offset + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  // Publish s_variance to every thread of the block before it is consumed.
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[row_offset + idx];
    out[row_offset + idx] = ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

/* Converter structs for the conversion from torch types to HIP/CUDA types,
   and the associated type conversions within HIP/CUDA. These helpers need
   to be implemented for now because the relevant type conversion
   operators/constructors are not consistently implemented by HIP/CUDA, so
   a generic conversion via type casts cannot be implemented.

   Each struct should have the member static constexpr bool `exists`:
   If false, the optimized kernel is not used for the corresponding torch type.
   If true, the struct should be fully defined as shown in the examples below.
*/ template struct _typeConvert { static constexpr bool exists = false; }; #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) // CUDA < 12.0 runs into issues with packed type conversion template <> struct _typeConvert { static constexpr bool exists = true; using hip_type = __half; using packed_hip_type = __half2; __device__ static inline float convert(hip_type x) { return __half2float(x); } __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } __device__ static inline hip_type convert(float x) { return __float2half_rn(x); } __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } }; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // CUDA_ARCH < 800 does not have BF16 support // TODO: Add in ROCm support once public headers handle bf16 maturely template <> struct _typeConvert { static constexpr bool exists = true; using hip_type = __nv_bfloat16; using packed_hip_type = __nv_bfloat162; __device__ static inline float convert(hip_type x) { return __bfloat162float(x); } __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } }; #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= // 12000)) /* Vector POD struct to generate vectorized and packed FP16/BF16 ops for appropriate specializations of fused_add_rms_norm_kernel. Only functions that are necessary in that kernel are implemented. Alignment to 16 bytes is required to use 128-bit global memory ops. 
*/ template struct alignas(16) _f16Vec { /* Not theoretically necessary that width is a power of 2 but should almost always be the case for optimization purposes */ static_assert(width > 0 && (width & (width - 1)) == 0, "Width is not a positive power of 2!"); using Converter = _typeConvert; using T1 = typename Converter::hip_type; using T2 = typename Converter::packed_hip_type; T1 data[width]; __device__ _f16Vec& operator+=(const _f16Vec& other) { if constexpr (width % 2 == 0) { #pragma unroll for (int i = 0; i < width; i += 2) { T2 temp{data[i], data[i + 1]}; temp += T2{other.data[i], other.data[i + 1]}; data[i] = temp.x; data[i + 1] = temp.y; } } else { #pragma unroll for (int i = 0; i < width; ++i) data[i] += other.data[i]; } return *this; } __device__ _f16Vec& operator*=(const _f16Vec& other) { if constexpr (width % 2 == 0) { #pragma unroll for (int i = 0; i < width; i += 2) { T2 temp{data[i], data[i + 1]}; temp *= T2{other.data[i], other.data[i + 1]}; data[i] = temp.x; data[i + 1] = temp.y; } } else { #pragma unroll for (int i = 0; i < width; ++i) data[i] *= other.data[i]; } return *this; } __device__ _f16Vec& operator*=(const float scale) { if constexpr (width % 2 == 0) { #pragma unroll for (int i = 0; i < width; i += 2) { float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); temp_f.x *= scale; temp_f.y *= scale; T2 temp = Converter::convert(temp_f); data[i] = temp.x; data[i + 1] = temp.y; } } else { #pragma unroll for (int i = 0; i < width; ++i) { float temp = Converter::convert(data[i]) * scale; data[i] = Converter::convert(temp); } } return *this; } __device__ float sum_squares() const { float result = 0.0f; if constexpr (width % 2 == 0) { #pragma unroll for (int i = 0; i < width; i += 2) { float2 z = Converter::convert(T2{data[i], data[i + 1]}); result += z.x * z.x + z.y * z.y; } } else { #pragma unroll for (int i = 0; i < width; ++i) { float x = Converter::convert(data[i]); result += x * x; } } return result; } }; /* Function specialization in 
the case of FP16/BF16 tensors. Additional optimizations we can make in this case are packed and vectorized operations, which help with the memory latency bottleneck. */ template __global__ std::enable_if_t<(width > 0) && _typeConvert::exists> fused_add_rms_norm_kernel( scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size] const scalar_t* __restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, const int hidden_size) { // Sanity checks on our vector struct and type-punned pointer arithmetic static_assert(std::is_pod_v<_f16Vec>); static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); const int vec_hidden_size = hidden_size / width; __shared__ float s_variance; float variance = 0.0f; /* These and the argument pointers are all declared `restrict` as they are not aliased in practice. Argument pointers should not be dereferenced in this kernel as that would be undefined behavior */ auto* __restrict__ input_v = reinterpret_cast<_f16Vec*>(input); auto* __restrict__ residual_v = reinterpret_cast<_f16Vec*>(residual); auto* __restrict__ weight_v = reinterpret_cast*>(weight); for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { int id = blockIdx.x * vec_hidden_size + idx; _f16Vec temp = input_v[id]; temp += residual_v[id]; variance += temp.sum_squares(); residual_v[id] = temp; } /* Keep the following if-else block in sync with the calculation of max_block_size in fused_add_rms_norm */ if (num_tokens < 256) { variance = blockReduceSum(variance); } else variance = blockReduceSum(variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } __syncthreads(); for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { int id = blockIdx.x * vec_hidden_size + idx; _f16Vec temp = residual_v[id]; temp *= s_variance; temp *= weight_v[idx]; input_v[id] = temp; } } /* Generic fused_add_rms_norm_kernel The width field is not used here but necessary 
for other specializations. */ template __global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> fused_add_rms_norm_kernel( scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size] const scalar_t* __restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { scalar_t z = input[blockIdx.x * hidden_size + idx]; z += residual[blockIdx.x * hidden_size + idx]; float x = (float)z; variance += x * x; residual[blockIdx.x * hidden_size + idx] = z; } /* Keep the following if-else block in sync with the calculation of max_block_size in fused_add_rms_norm */ if (num_tokens < 256) { variance = blockReduceSum(variance); } else variance = blockReduceSum(variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { float x = (float)residual[blockIdx.x * hidden_size + idx]; input[blockIdx.x * hidden_size + idx] = ((scalar_t)(x * s_variance)) * weight[idx]; } } } // namespace vllm template __inline__ __device__ T WarpReduceSum_NEW(T val) { #pragma unroll for (int offset = reducesize/2; offset > 0; offset >>= 1) { val += WARP_SHFL_DOWN(val, offset); } return val; } template __inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) { constexpr int share_size=block_size/C10_WARP_SIZE; val = WarpReduceSum_NEW(val); if constexpr(block_size==C10_WARP_SIZE) { return val; } else{ const int lid = threadIdx.x % C10_WARP_SIZE; const int wid = threadIdx.x / C10_WARP_SIZE; if (lid == 0&&wid(shared[lid]); } return val; } } template __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,scalar_t* gamma,int cols,T_ACC eps) { constexpr int share_size=block_size/C10_WARP_SIZE; __shared__ T_ACC val_shared[share_size]; __shared__ T_ACC s_rstd; 
T_ACC val=0; int i=blockIdx.x; int j=threadIdx.x; int tcol=cols/Vec; using LoadT = at::native::memory::aligned_vector; scalar_t intput_vec[Vec]; scalar_t residual_vec[Vec]; T_ACC trstd; int idx = i * tcol + j; idx*=Vec; *(LoadT*)intput_vec = *(LoadT*)(input+idx); *(LoadT*)residual_vec = *(LoadT*)(residual+idx); if (j < tcol) { #pragma unroll for (int ii = 0; ii < Vec; ii++) { residual_vec[ii]+=intput_vec[ii]; val += static_cast(residual_vec[ii])*static_cast(residual_vec[ii]); } } val = BlockReduceSum_NEW(val,val_shared); if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps); __syncthreads(); trstd=s_rstd; if (j < tcol) { #pragma unroll for(int ii=0;ii(residual_vec[ii]) *trstd* static_cast(gamma[jj]); } *(LoadT*)(residual+idx)=*(LoadT*)residual_vec; *(LoadT*)(input+idx)=*(LoadT*)intput_vec; } } template __global__ void fused_rms_kernel_eval(scalar_t* input,scalar_t* output,scalar_t* gamma,int cols,T_ACC eps) { constexpr int share_size=block_size/C10_WARP_SIZE; __shared__ T_ACC val_shared[share_size]; __shared__ T_ACC s_rstd; T_ACC val=0; int i=blockIdx.x; int j=threadIdx.x; int tcol=cols/Vec; using LoadT = at::native::memory::aligned_vector; scalar_t intput_vec[Vec]; T_ACC trstd; int idx = i * tcol + j; idx*=Vec; *(LoadT*)intput_vec = *(LoadT*)(input+idx); if (j < tcol) { #pragma unroll for (int ii = 0; ii < Vec; ii++) { val += static_cast(intput_vec[ii])*static_cast(intput_vec[ii]); } } val = BlockReduceSum_NEW(val,val_shared); if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps); __syncthreads(); trstd=s_rstd; if (j < tcol) { #pragma unroll for(int ii=0;ii(intput_vec[ii]) *trstd* static_cast(gamma[jj]); } *(LoadT*)(output+idx)=*(LoadT*)intput_vec; } } void rms_norm(torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); 
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); auto inp_ptr = reinterpret_cast(input.data_ptr()); auto wt_ptr = reinterpret_cast(weight.data_ptr()); bool ptrs_are_aligned =inp_ptr % 16 == 0 && wt_ptr % 16 == 0; if(!use_old&&hidden_size%16==0&&hidden_size<=16384&&ptrs_are_aligned){ AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "fused_add_rms_norm_kernel", [&] { using T_ACC = at::acc_type; T_ACC eps = epsilon; scalar_t* self_data = input.data_ptr(); scalar_t* out_data =out.data_ptr(); scalar_t* weight_data=weight.data_ptr(); if (hidden_size<=1024){ fused_rms_kernel_eval<<>>(self_data,out_data,weight_data,hidden_size,eps); } else if(hidden_size<=2048){ fused_rms_kernel_eval<<>>(self_data,out_data,weight_data,hidden_size,eps); } else if(hidden_size<=4096){ if(num_tokens>1200){ fused_rms_kernel_eval<<>>(self_data,out_data,weight_data,hidden_size,eps); } else{ fused_rms_kernel_eval<<>>(self_data,out_data,weight_data,hidden_size,eps); } } else if(hidden_size<=8192){ fused_rms_kernel_eval<<>>(self_data,out_data,weight_data,hidden_size,eps); } else{ fused_rms_kernel_eval<<>>(self_data,out_data,weight_data,hidden_size,eps); } }); } else{ dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { vllm::rms_norm_kernel<<>>( out.data_ptr(), input.data_ptr(), weight.data_ptr(), epsilon, num_tokens, hidden_size); }); } } #define LAUNCH_FUSED_ADD_RMS_NORM(width) \ VLLM_DISPATCH_FLOATING_TYPES( \ input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ vllm::fused_add_rms_norm_kernel \ <<>>(input.data_ptr(), \ residual.data_ptr(), \ weight.data_ptr(), epsilon, \ num_tokens, hidden_size); \ }); void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / 
hidden_size; const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); auto inp_ptr = reinterpret_cast(input.data_ptr()); auto res_ptr = reinterpret_cast(residual.data_ptr()); auto wt_ptr = reinterpret_cast(weight.data_ptr()); bool ptrs_are_aligned =inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; if(!use_old&&hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192&&ptrs_are_aligned){ AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "fused_add_rms_norm_kernel", [&] { using T_ACC = at::acc_type; T_ACC eps = epsilon; scalar_t* self_data = input.data_ptr(); scalar_t* other_data =residual.data_ptr(); scalar_t* weight_data=weight.data_ptr(); if (hidden_size<=1024){ fused_add_rms_kernel_eval<<>>(self_data,other_data,weight_data,hidden_size,eps); } else if(hidden_size<=2048){ fused_add_rms_kernel_eval<<>>(self_data,other_data,weight_data,hidden_size,eps); } else if(hidden_size<=4096){ if(num_tokens>1200){ fused_add_rms_kernel_eval<<>>(self_data,other_data,weight_data,hidden_size,eps); } else{ fused_add_rms_kernel_eval<<>>(self_data,other_data,weight_data,hidden_size,eps); } } else if(hidden_size<=8192){ fused_add_rms_kernel_eval<<>>(self_data,other_data,weight_data,hidden_size,eps); } else{ fused_add_rms_kernel_eval<<>>(self_data,other_data,weight_data,hidden_size,eps); } }); } else{ dim3 grid(num_tokens); /* This kernel is memory-latency bound in many scenarios. When num_tokens is large, a smaller block size allows for increased block occupancy on CUs and better latency hiding on global mem ops. */ const int max_block_size = (num_tokens < 256) ? 1024 : 256; dim3 block(std::min(hidden_size, max_block_size)); /*If the tensor types are FP16/BF16, try to use the optimized kernel with packed + vectorized ops. Max optimization is achieved with a width-8 vector of FP16/BF16s since we can load at most 128 bits at once in a global memory op. 
However, this requires each tensor's data to be aligned to 16 bytes. */ if (ptrs_are_aligned && hidden_size % 8 == 0) { LAUNCH_FUSED_ADD_RMS_NORM(8); } else { LAUNCH_FUSED_ADD_RMS_NORM(0); } } }