/*
 * This file contains the CUDA kernels for the fused quantized layernorm.
 * The kernels correspond to the kernels in layernorm_kernels.cu, except that
 * they also produce quantized output directly.
 * Currently, only static FP8 quantization is supported.
 */

#include "type_convert.cuh"
#include "quantization/w8a8/fp8/common.cuh"
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include "libtorch_stable/quantization/vectorization_utils.cuh"

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <type_traits>

namespace vllm {

// RMSNorm of one token row followed by static FP8 quantization. Each output
// element is x * rsqrt(mean(x^2) + epsilon) * weight, rounded through
// scalar_t, then multiplied by 1 / *scale and converted to fp8_type by
// scaled_fp8_conversion.
//
// Launch layout: one thread block per token (gridDim.x == num_tokens) with at
// most 1024 threads, matching cub::BlockReduce<float, 1024> below.
// Preconditions (established by rms_norm_static_fp8_quant on the host):
//   * hidden_size is a multiple of VEC_SIZE;
//   * every input row start and `weight` are aligned to
//     VEC_SIZE * sizeof(scalar_t) bytes, because the second pass loads them
//     through vec_n_t<scalar_t, VEC_SIZE> pointers;
//   * `out` is contiguous with rows of hidden_size elements.
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, typename fp8_type, int VEC_SIZE>
__global__ void rms_norm_static_fp8_quant_kernel(
    fp8_type* __restrict__ out,          // [..., hidden_size]
    const scalar_t* __restrict__ input,  // [..., hidden_size]
    const int input_stride,
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float* __restrict__ scale,      // [1]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // Row bases are computed in 64 bits: blockIdx.x * input_stride (or
  // * hidden_size) in 32-bit arithmetic wraps for very large token counts.
  const int64_t row = blockIdx.x;
  const scalar_t* input_row = input + row * input_stride;
  fp8_type* out_row = out + row * hidden_size;

  auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
#pragma unroll
    for (int i = 0; i < VEC_SIZE; ++i) {
      float x = static_cast<float>(vec.val[i]);
      variance += x * x;
    }
  };
  auto scalar_op = [&variance](const scalar_t& val) {
    float x = static_cast<float>(val);
    variance += x * x;
  };
  vllm::vectorize_read_with_alignment<VEC_SIZE>(
      input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // invert scale to avoid division
  float const scale_inv = 1.0f / *scale;

  using vec_type = vec_n_t<scalar_t, VEC_SIZE>;
  auto* v_in = reinterpret_cast<const vec_type*>(input_row);
  auto* v_w = reinterpret_cast<const vec_type*>(weight);
  const int num_vecs = hidden_size / VEC_SIZE;
  for (int idx = threadIdx.x; idx < num_vecs; idx += blockDim.x) {
    vec_type src1 = v_in[idx];
    vec_type src2 = v_w[idx];
#pragma unroll
    for (int j = 0; j < VEC_SIZE; j++) {
      float x = static_cast<float>(src1.val[j]);
      float w = static_cast<float>(src2.val[j]);
      // Round normalized result through scalar_t to match the precision of the
      // unfused composite (rms_norm writes scalar_t, then
      // static_scaled_fp8_quant re-loads it as float before FP8 conversion).
      // Without this round, the fused path is strictly more accurate and
      // disagrees with the composite at exact E4M3 quantization tie boundaries.
      scalar_t out_norm = static_cast<scalar_t>(x * s_variance * w);
      out_row[idx * VEC_SIZE + j] = scaled_fp8_conversion<true, fp8_type>(
          static_cast<float>(out_norm), scale_inv);
    }
  }
}

/* Function specialization in the case of FP16/BF16 tensors.
   Additional optimizations we can make in this case are
   packed and vectorized operations, which help with the
   memory latency bottleneck.

   Launch layout: one thread block per token, at most 1024 threads.
   Preconditions (checked by the host launcher before choosing width > 0):
   input, residual and weight are 16-byte aligned, and hidden_size and
   input_stride are multiples of `width`. residual is updated in place
   with input + residual. */
template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
    fp8_type* __restrict__ out,    // [..., hidden_size]
    scalar_t* __restrict__ input,  // [..., hidden_size]
    const int input_stride,
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float* __restrict__ scale,      // [1]
    const float epsilon, const int num_tokens, const int hidden_size) {
  using vec_type = _f16Vec<scalar_t, width>;
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<vec_type>);
  static_assert(sizeof(vec_type) == sizeof(scalar_t) * width);

  const int vec_hidden_size = hidden_size / width;
  __shared__ float s_variance;
  float variance = 0.0f;

  // 64-bit row bases; the previous `int` element indices overflowed once
  // num_tokens * hidden_size reached 2^31.
  const int64_t row = blockIdx.x;
  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<vec_type*>(input + row * input_stride);
  auto* __restrict__ residual_v =
      reinterpret_cast<vec_type*>(residual + row * hidden_size);
  auto* __restrict__ weight_v = reinterpret_cast<const vec_type*>(weight);
  fp8_type* out_row = out + row * hidden_size;

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    vec_type temp = input_v[idx];
    temp += residual_v[idx];
    variance += temp.sum_squares();
    residual_v[idx] = temp;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // invert scale to avoid division
  float const scale_inv = 1.0f / *scale;

  using Converter = _typeConvert<scalar_t>;
  using HipT = typename Converter::hip_type;
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    vec_type res = residual_v[idx];
    vec_type w = weight_v[idx];
#pragma unroll
    for (int i = 0; i < width; ++i) {
      float x = Converter::convert(res.data[i]);
      float wf = Converter::convert(w.data[i]);
      // Round through scalar_t to match the unfused composite path (rms_norm
      // output in scalar_t, then FP8 quantization) at FP8 tie boundaries. We
      // use the backend's hip_type for the intermediate since
      // c10::Half/BFloat16 has ambiguous conversions on CUDA and no implicit
      // conversion on ROCm.
      HipT out_norm_h = Converter::convert(x * s_variance * wf);
      out_row[idx * width + i] = scaled_fp8_conversion<true, fp8_type>(
          Converter::convert(out_norm_h), scale_inv);
    }
  }
}

/* Generic fused_add_rms_norm kernel (scalar loads, any scalar_t).
   The width field is not used here but necessary for other specializations.
   Launch layout: one thread block per token, at most 1024 threads.
   residual is updated in place with input + residual. */
template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
    fp8_type* __restrict__ out,    // [..., hidden_size]
    scalar_t* __restrict__ input,  // [..., hidden_size]
    const int input_stride,
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float* __restrict__ scale,      // [1]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // 64-bit row bases to avoid 32-bit overflow for very large token counts.
  const int64_t row = blockIdx.x;
  const scalar_t* input_row = input + row * input_stride;
  scalar_t* residual_row = residual + row * hidden_size;
  fp8_type* out_row = out + row * hidden_size;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    scalar_t z = input_row[idx];
    z += residual_row[idx];
    float x = (float)z;
    variance += x * x;
    residual_row[idx] = z;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // invert scale to avoid division
  float const scale_inv = 1.0f / *scale;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)residual_row[idx];
    float w = (float)weight[idx];
    // Round through scalar_t to match the unfused composite path (rms_norm
    // output in scalar_t, then FP8 quantization) at FP8 tie boundaries.
    scalar_t out_norm = static_cast<scalar_t>(x * s_variance * w);
    out_row[idx] = scaled_fp8_conversion<true, fp8_type>(
        static_cast<float>(out_norm), scale_inv);
  }
}

}  // namespace vllm

// Host launcher: out = fp8(rms_norm(input) * weight / scale).
// input may have a row stride larger than hidden_size, but its last
// dimension must be unit-stride and its leading dimensions must be flattenable
// into a single token dimension.
void rms_norm_static_fp8_quant(torch::Tensor& out,     // [..., hidden_size]
                               torch::Tensor& input,   // [..., hidden_size]
                               torch::Tensor& weight,  // [hidden_size]
                               torch::Tensor& scale,   // [1]
                               double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(weight.is_contiguous());
  TORCH_CHECK(weight.scalar_type() == input.scalar_type());

  // Nothing to do. Returning here also avoids dividing by hidden_size == 0
  // and launching a zero-sized grid (an invalid launch configuration).
  if (input.numel() == 0) {
    return;
  }
  // The kernel walks each row with unit stride.
  TORCH_CHECK(input.stride(-1) == 1);

  int hidden_size = input.size(-1);
  // Flatten leading dims to get a single token stride; view() throws if they
  // are not uniformly strided instead of letting rows be mis-indexed.
  int input_stride =
      static_cast<int>(input.view({-1, hidden_size}).stride(-2));
  int num_tokens = static_cast<int>(input.numel() / hidden_size);

  // For large num_tokens, use smaller blocks to increase SM concurrency.
  // NOTE(review): per-thread partial sums of the variance depend on
  // blockDim.x, so switching block size at 256 tokens can change rounding as
  // the batch size changes; confirm whether this matters when
  // vllm::vllm_is_batch_invariant() is set.
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 grid(num_tokens);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const auto input_addr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  const auto weight_addr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
              // Widest power-of-two vector (<= 16 bytes) dividing hidden_size.
              int vec_size = std::gcd(
                  static_cast<int>(16 / sizeof(scalar_t)), hidden_size);
              // The kernel's second pass loads input rows and weight as whole
              // vectors, so every row start (base + k * input_stride) and the
              // weight pointer must be aligned to the vector size in bytes.
              vec_size = std::gcd(vec_size, input_stride);
              while (vec_size > 1) {
                const std::uintptr_t vec_bytes = vec_size * sizeof(scalar_t);
                if (input_addr % vec_bytes == 0 &&
                    weight_addr % vec_bytes == 0) {
                  break;
                }
                vec_size /= 2;
              }
              const int block_size =
                  std::min(hidden_size / vec_size, max_block_size);
              dim3 block(block_size);
              VLLM_DISPATCH_VEC_SIZE(vec_size, [&] {
                vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t,
                                                       VEC_SIZE>
                    <<<grid, block, 0, stream>>>(
                        out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                        input_stride, weight.data_ptr<scalar_t>(),
                        scale.data_ptr<float>(), epsilon, num_tokens,
                        hidden_size);
              });
            });
      });
}

// Dispatches on input (FP16/BF16/FP32) and output FP8 dtypes and launches the
// fused add + RMSNorm + quant kernel with the given vector width.
#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                     \
  VLLM_DISPATCH_FLOATING_TYPES(                                              \
      input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] {    \
        VLLM_DISPATCH_FP8_TYPES(                                             \
            out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] {   \
              vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t,     \
                                                               width, fp8_t> \
                  <<<grid, block, 0, stream>>>(                              \
                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),     \
                      input_stride, residual.data_ptr<scalar_t>(),           \
                      weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),  \
                      epsilon, num_tokens, hidden_size);                     \
            });                                                              \
      });

// Host launcher: residual += input (in place), then
// out = fp8(rms_norm(residual) * weight / scale).
void fused_add_rms_norm_static_fp8_quant(
    torch::Tensor& out,       // [..., hidden_size]
    torch::Tensor& input,     // [..., hidden_size]
    torch::Tensor& residual,  // [..., hidden_size]
    torch::Tensor& weight,    // [hidden_size]
    torch::Tensor& scale,     // [1]
    double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(residual.is_contiguous());
  TORCH_CHECK(weight.is_contiguous());
  TORCH_CHECK(residual.scalar_type() == input.scalar_type());
  TORCH_CHECK(weight.scalar_type() == input.scalar_type());

  // Nothing to do. Returning here also avoids dividing by hidden_size == 0
  // and launching a zero-sized grid (an invalid launch configuration).
  if (input.numel() == 0) {
    return;
  }
  // The kernels walk each row with unit stride.
  TORCH_CHECK(input.stride(-1) == 1);

  int hidden_size = input.size(-1);
  // Flatten leading dims to get a single token stride; view() throws if they
  // are not uniformly strided instead of letting rows be mis-indexed.
  int input_stride =
      static_cast<int>(input.view({-1, hidden_size}).stride(-2));
  int num_tokens = static_cast<int>(input.numel() / hidden_size);

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops.
     NOTE(review): the variance's per-thread partial sums depend on
     blockDim.x, so this switch may change rounding with batch size; confirm
     whether it matters in batch-invariant mode. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  /*If the tensor types are FP16/BF16, try to use the optimized kernel
    with packed + vectorized ops.
    Max optimization is achieved with a width-8 vector of FP16/BF16s
    since we can load at most 128 bits at once in a global memory op.
    However, this requires each tensor's data to be aligned to 16
    bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
  bool ptrs_are_aligned =
      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
  if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
      !batch_invariant_launch) {
    LAUNCH_FUSED_ADD_RMS_NORM(8);
  } else {
    LAUNCH_FUSED_ADD_RMS_NORM(0);
  }
}