// SPDX-License-Identifier: MIT
// NOTE(review): text between angle brackets (the targets of three #include
// directives, the target type of static_cast, the parameter of std::optional)
// is missing from this view of the file. The code tokens below are left exactly
// as found; restore the missing pieces from the original source before building.
#include "py_itfs_common.h"
#include
#include
#include
#include "rmsnorm2d_fwd.hpp"

// NOTE(review): every entry point below originally annotated `epsilon` with
// "0: Use default RMSNorm; 1: Use T5-like implementation". The code only casts
// the value and places it in the rmsnorm2d_fwd argument list, so that remark
// looks stale — confirm before relying on it.

// Declared here but, within this part of the file, referenced only from the
// commented-out bf16 bypass inside rmsnorm2d_with_add.
void fused_add_rms_norm_out(torch::Tensor& out,
                            torch::Tensor& input,
                            torch::Tensor& residual_in,
                            torch::Tensor& residual_out,
                            torch::Tensor& weight,
                            double epsilon);

namespace {
// True when `tensor` is 2-D, its last-dim stride is 1 and its stride(0) equals
// the size of its last dimension. Only used by the commented-out bypass below.
bool is_dense_row_major_2d(const torch::Tensor& tensor)
{
    return tensor.dim() == 2 && tensor.stride(-1) == 1 && tensor.stride(0) == tensor.size(-1);
}

// True when `tensor` has at least one dimension, a last-dim stride of 1 and
// reports is_contiguous(). Only used by the commented-out bypass below.
bool is_dense_last_dim(const torch::Tensor& tensor)
{
    return tensor.dim() >= 1 && tensor.stride(-1) == 1 && tensor.is_contiguous();
}
} // namespace

// Calls rmsnorm2d_fwd with fused_add = 0 and fused_quant = 0, writing into `out`.
// The input dtype (fp16 or bf16, enforced) is used for all four precision
// strings. n is the last-dim size of `input`, m is numel / n, and the row
// strides are taken from stride(0) of `input` and `out`; both residual strides
// are passed as -1. No dtype/shape check is made on `out` or `weight` here.
void rmsnorm2d(torch::Tensor& out,    // [m, n]
               torch::Tensor& input,  // [m, n]
               torch::Tensor& weight, // [1, n]
               double epsilon)        // cast and passed through to rmsnorm2d_fwd
{
    auto dtype = input.dtype();
    TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16,
                "ck rmsnorm2d only support fp16 and bf16 data type");

    std::string dtype_str = torchDTypeToStr(dtype);
    int n                 = input.size(-1);
    int m                 = input.numel() / n;
    int stride            = input.stride(0);
    int xr_stride         = -1; // no residual input
    int y_stride          = out.stride(0);
    int yr_stride         = -1; // no residual output
    bool SaveRms          = false;

    // Make the device of `input` current and launch on its current HIP stream.
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
    const hipStream_t stream = at::hip::getCurrentHIPStream();

    rmsnorm2d_fwd({dtype_str, // input precision
                   dtype_str, // output precision
                   dtype_str, // x-scale, used for [1*N] input smooth quant
                   dtype_str, // y-scale, used for [M*1] output for next layer
                   SaveRms,
                   false, // save_unquant
                   0,     // fused_add
                   0,     // fused_quant
                  },
                  {input.data_ptr(),
                   nullptr, // p_x_residual
                   nullptr, // p_x_scale
                   weight.data_ptr(),
                   out.data_ptr(),
                   nullptr, // p_y_residual
                   nullptr, // p_y_scale
                   nullptr, // p_invRms
                   nullptr, // p_y_unquant
                   static_cast(epsilon),
                   m,
                   n,
                   stride,
                   xr_stride,
                   y_stride,
                   yr_stride},
                  {stream});
}

// Allocating overload: creates `out` with torch::empty_like(input), runs the
// in-place overload above and returns `out`.
torch::Tensor rmsnorm2d(torch::Tensor& input,  // [m, n]
                        torch::Tensor& weight, // [1, n]
                        double epsilon)        // cast and passed through to rmsnorm2d_fwd
{
    torch::Tensor out = torch::empty_like(input);
    rmsnorm2d(out, input, weight, epsilon);
    return out;
}

// Calls rmsnorm2d_fwd with fused_add = 1 and fused_quant = 0. Besides the
// fp16/bf16 check, this variant requires out/residual_in/residual_out/weight to
// share the input dtype and out/residual_in/residual_out to share the input
// shape. Residual row strides come from stride(0) of the residual tensors.
void rmsnorm2d_with_add(torch::Tensor& out,          // [m ,n]
                        torch::Tensor& input,        // [m ,n]
                        torch::Tensor& residual_in,  // [m ,n]
                        torch::Tensor& residual_out, // [m ,n]
                        torch::Tensor& weight,       // [1 ,n]
                        double epsilon)              // cast and passed through to rmsnorm2d_fwd
{
    auto dtype = input.dtype();
    TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16,
                "ck rmsnorm2d only support fp16 and bf16 data type");
    TORCH_CHECK(out.dtype() == dtype && residual_in.dtype() == dtype &&
                    residual_out.dtype() == dtype && weight.dtype() == dtype,
                "rmsnorm2d_with_add expects input/output/residual/weight to share the same dtype");
    TORCH_CHECK(input.sizes() == out.sizes() && input.sizes() == residual_in.sizes() &&
                    input.sizes() == residual_out.sizes(),
                "rmsnorm2d_with_add expects input/out/residual tensors to have the same shape");

    std::string dtype_str = torchDTypeToStr(input.dtype());
    int n                 = input.size(-1);
    int m                 = input.numel() / n;
    int stride            = input.stride(0);
    int xr_stride         = residual_in.stride(0);
    int y_stride          = out.stride(0);
    int yr_stride         = residual_out.stride(0);
    bool SaveRms          = false;

    // Make the device of `input` current and launch on its current HIP stream.
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
    const hipStream_t stream = at::hip::getCurrentHIPStream();

    // The bypass below is disabled; it is the only user of the helpers in the
    // anonymous namespace and of fused_add_rms_norm_out visible in this file part.
    // 1. For bf16, we can also choose the vllm-like solution.
    // const bool can_use_vllm_bf16_bypass = dtype == torch::kBFloat16 && is_dense_row_major_2d(input) &&
    // is_dense_row_major_2d(out) &&
    // is_dense_row_major_2d(residual_in) &&
    // is_dense_row_major_2d(residual_out) &&
    // is_dense_last_dim(weight);
    // if(can_use_vllm_bf16_bypass)
    // {
    // fused_add_rms_norm_out(out, input, residual_in, residual_out, weight, epsilon);
    // return;
    // }
    // 2. CK solution
    rmsnorm2d_fwd({dtype_str, // input precision
                   dtype_str, // output precision
                   dtype_str, // x-scale, used for [1*N] input smooth quant
                   dtype_str, // y-scale, used for [M*1] output for next layer
                   SaveRms,
                   false, // save_unquant
                   1,     // fused_add
                   0,     // fused_quant
                  },
                  {input.data_ptr(),        // p_x
                   residual_in.data_ptr(),  // p_x_residual
                   nullptr,                 // p_x_scale
                   weight.data_ptr(),       // p_gamma
                   out.data_ptr(),          // p_y
                   residual_out.data_ptr(), // p_y_residual
                   nullptr,                 // p_y_scale
                   nullptr,                 // p_invRms
                   nullptr,                 // p_y_unquant
                   static_cast(epsilon),
                   m,
                   n,
                   stride,
                   xr_stride,
                   y_stride,
                   yr_stride},
                  {stream});
}

// Calls rmsnorm2d_fwd with fused_add = 0 and fused_quant = 1. The output,
// x-scale and y-scale precision strings are derived from the dtypes of `out`,
// `xscale` and `yscale`; only the input dtype is validated (fp16/bf16).
void rmsnorm2d_with_smoothquant(torch::Tensor& out,    // [m ,n]
                                torch::Tensor& input,  // [m ,n]
                                torch::Tensor& xscale, // [1 ,n]
                                torch::Tensor& yscale, // [m ,1]
                                torch::Tensor& weight, // [1 ,n]
                                double epsilon)        // cast and passed through to rmsnorm2d_fwd
{
    auto dtype = input.dtype();
    TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16,
                "ck rmsnorm2d only support fp16 and bf16 data type");

    std::string dtype_str        = torchDTypeToStr(input.dtype());
    std::string out_dtype_str    = torchDTypeToStr(out.dtype());
    std::string xscale_dtype_str = torchDTypeToStr(xscale.dtype());
    std::string yscale_dtype_str = torchDTypeToStr(yscale.dtype());

    int n         = input.size(-1);
    int m         = input.numel() / n;
    int stride    = input.stride(0);
    int xr_stride = -1; // no residual input
    int y_stride  = out.stride(0);
    int yr_stride = -1; // no residual output
    bool SaveRms  = false;

    // Make the device of `input` current and launch on its current HIP stream.
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
    const hipStream_t stream = at::hip::getCurrentHIPStream();

    rmsnorm2d_fwd({dtype_str,        // input precision
                   out_dtype_str,    // output precision
                   xscale_dtype_str, // x-scale, used for [1*N] input smooth quant
                   yscale_dtype_str, // y-scale, used for [M*1] output for next layer
                   SaveRms,
                   false, // save_unquant
                   0,     // fused_add
                   1,     // fused_quant
                  },
                  {input.data_ptr(),  // p_x
                   nullptr,           // p_x_residual
                   xscale.data_ptr(), // p_x_scale
                   weight.data_ptr(), // p_gamma
                   out.data_ptr(),    // p_y
                   nullptr,           // p_y_residual
                   yscale.data_ptr(), // p_y_scale
                   nullptr,           // p_invRms
                   nullptr,           // p_y_unquant
                   static_cast(epsilon),
                   m,
                   n,
                   stride,
                   xr_stride,
                   y_stride,
                   yr_stride},
                  {stream});
}

// Calls rmsnorm2d_fwd with fused_add = 1 and fused_quant = 1. When
// `out_before_quant` holds a value, save_unquant is set to true and that
// tensor's data pointer is passed as p_y_unquant; otherwise nullptr is passed.
void rmsnorm2d_with_add_smoothquant(torch::Tensor& out,          // [m ,n]
                                    torch::Tensor& input,        // [m ,n]
                                    torch::Tensor& residual_in,  // [m ,n]
                                    torch::Tensor& residual_out, // [m ,n]
                                    torch::Tensor& xscale,       // [1 ,n]
                                    torch::Tensor& yscale,       // [m ,1]
                                    torch::Tensor& weight,       // [1 ,n]
                                    double epsilon,              // cast and passed through to rmsnorm2d_fwd
                                    std::optional out_before_quant) // optional destination for p_y_unquant
{
    auto dtype = input.dtype();
    TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16,
                "ck rmsnorm2d only support fp16 and bf16 data type");

    std::string dtype_str        = torchDTypeToStr(input.dtype());
    std::string out_dtype_str    = torchDTypeToStr(out.dtype());
    std::string xscale_dtype_str = torchDTypeToStr(xscale.dtype());
    std::string yscale_dtype_str = torchDTypeToStr(yscale.dtype());

    int n         = input.size(-1);
    int m         = input.numel() / n;
    int stride    = input.stride(0);
    int xr_stride = residual_in.stride(0);
    int y_stride  = out.stride(0);
    int yr_stride = residual_out.stride(0);
    bool SaveRms  = false;

    // Make the device of `input` current and launch on its current HIP stream.
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
    const hipStream_t stream = at::hip::getCurrentHIPStream();

    rmsnorm2d_fwd({dtype_str,        // input precision
                   out_dtype_str,    // output precision
                   xscale_dtype_str, // x-scale, used for [1*N] input smooth quant
                   yscale_dtype_str, // y-scale, used for [M*1] output for next layer
                   SaveRms,
                   out_before_quant.has_value(), // save_unquant
                   1,                            // fused_add
                   1,                            // fused_quant
                  },
                  {input.data_ptr(),        // p_x
                   residual_in.data_ptr(),  // p_x_residual
                   xscale.data_ptr(),       // p_x_scale
                   weight.data_ptr(),       // p_gamma
                   out.data_ptr(),          // p_y
                   residual_out.data_ptr(), // p_y_residual
                   yscale.data_ptr(),       // p_y_scale
                   nullptr,                 // p_invRms
                   out_before_quant.has_value() ? out_before_quant.value().data_ptr()
                                                : nullptr, // p_y_unquant
                   static_cast(epsilon),
                   m,
                   n,
                   stride,
                   xr_stride,
                   y_stride,
                   yr_stride},
                  {stream});
}

// Calls rmsnorm2d_fwd with fused_add = 0 and fused_quant = 2. No x-scale tensor
// is supplied (p_x_scale is nullptr and the x-scale precision string reuses the
// input dtype); `yscale` is passed as p_y_scale.
void rmsnorm2d_with_dynamicquant(torch::Tensor& out,    // [m ,n]
                                 torch::Tensor& input,  // [m ,n]
                                 torch::Tensor& yscale, // [m ,1]
                                 torch::Tensor& weight, // [1 ,n]
                                 double epsilon)        // cast and passed through to rmsnorm2d_fwd
{
    auto dtype = input.dtype();
    TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16,
                "ck rmsnorm2d only support fp16 and bf16 data type");

    std::string dtype_str        = torchDTypeToStr(input.dtype());
    std::string out_dtype_str    = torchDTypeToStr(out.dtype());
    std::string yscale_dtype_str = torchDTypeToStr(yscale.dtype());

    int n         = input.size(-1);
    int m         = input.numel() / n;
    int stride    = input.stride(0);
    int xr_stride = -1; // no residual input
    int y_stride  = out.stride(0);
    int yr_stride = -1; // no residual output
    bool SaveRms  = false;

    // Make the device of `input` current and launch on its current HIP stream.
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
    const hipStream_t stream = at::hip::getCurrentHIPStream();

    rmsnorm2d_fwd({dtype_str,        // input precision
                   out_dtype_str,    // output precision
                   dtype_str,        // x-scale, used for [1*N] input smooth quant
                   yscale_dtype_str, // y-scale, used for [M*1] output for next layer
                   SaveRms,
                   false, // save_unquant
                   0,     // fused_add
                   2,     // fused_quant
                  },
                  {input.data_ptr(),  // p_x
                   nullptr,           // p_x_residual
                   nullptr,           // p_x_scale
                   weight.data_ptr(), // p_gamma
                   out.data_ptr(),    // p_y
                   nullptr,           // p_y_residual
                   yscale.data_ptr(), // p_y_scale
                   nullptr,           // p_invRms
                   nullptr,           // p_y_unquant
                   static_cast(epsilon),
                   m,
                   n,
                   stride,
                   xr_stride,
                   y_stride,
                   yr_stride},
                  {stream});
}

// Calls rmsnorm2d_fwd with fused_add = 1 and fused_quant = 2: the residual
// tensors are passed as p_x_residual / p_y_residual, `yscale` as p_y_scale, and
// no x-scale tensor is supplied.
void rmsnorm2d_with_add_dynamicquant(torch::Tensor& out,          // [m ,n]
                                     torch::Tensor& input,        // [m ,n]
                                     torch::Tensor& residual_in,  // [m ,n]
                                     torch::Tensor& residual_out, // [m ,n]
                                     torch::Tensor& yscale,       // [m ,1]
                                     torch::Tensor& weight,       // [1 ,n]
                                     double epsilon)              // cast and passed through to rmsnorm2d_fwd
{
    auto dtype = input.dtype();
    TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16,
                "ck rmsnorm2d only support fp16 and bf16 data type");

    std::string dtype_str        = torchDTypeToStr(input.dtype());
    std::string out_dtype_str    = torchDTypeToStr(out.dtype());
    std::string yscale_dtype_str = torchDTypeToStr(yscale.dtype());

    int n         = input.size(-1);
    int m         = input.numel() / n;
    int stride    = input.stride(0);
    int xr_stride = residual_in.stride(0);
    int y_stride  = out.stride(0);
    int yr_stride = residual_out.stride(0);
    bool SaveRms  = false;

    // Make the device of `input` current and launch on its current HIP stream.
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
    const hipStream_t stream = at::hip::getCurrentHIPStream();

    rmsnorm2d_fwd({dtype_str,        // input precision
                   out_dtype_str,    // output precision
                   dtype_str,        // x-scale, used for [1*N] input smooth quant
                   yscale_dtype_str, // y-scale, used for [M*1] output for next layer
                   SaveRms,
                   false, // save_unquant
                   1,     // fused_add
                   2,     // fused_quant
                  },
                  {input.data_ptr(),        // p_x
                   residual_in.data_ptr(),  // p_x_residual
                   nullptr,                 // p_x_scale
                   weight.data_ptr(),       // p_gamma
                   out.data_ptr(),          // p_y
                   residual_out.data_ptr(), // p_y_residual
                   yscale.data_ptr(),       // p_y_scale
                   nullptr,                 // p_invRms
                   nullptr,                 // p_y_unquant
                   static_cast(epsilon),
                   m,
                   n,
                   stride,
                   xr_stride,
                   y_stride,
                   yr_stride},
                  {stream});
}

// ============================================================================
// head_rms_norm: per-head RMS normalization
// Applies RMS normalization to each head independently.
// input: [num_tokens, num_heads * head_dim] // weight: [num_heads * head_dim] // ============================================================================ template __device__ T hip_block_reduce_sum(T val) { __shared__ T sdata[1024]; int tid = threadIdx.x; sdata[tid] = val; __syncthreads(); for (int s = blockDim.x / 2; s > 0; s >>= 1) { if (tid < s) sdata[tid] += sdata[tid + s]; __syncthreads(); } return sdata[0]; } template __global__ void head_rms_norm_kernel( scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const scalar_t* __restrict__ weight, const float epsilon, const int num_tokens, const int num_heads, const int head_dim) { const int token_idx = blockIdx.x; const int head_idx = blockIdx.y; if (token_idx >= num_tokens || head_idx >= num_heads) return; const int hidden_size = num_heads * head_dim; const int base_offset = token_idx * hidden_size + head_idx * head_dim; const int weight_offset = head_idx * head_dim; // Compute variance for this head float variance = 0.0f; for (int d = threadIdx.x; d < head_dim; d += blockDim.x) { float x = static_cast(input[base_offset + d]); variance += x * x; } // Block reduce variance = hip_block_reduce_sum(variance); __shared__ float s_rms; if (threadIdx.x == 0) { s_rms = rsqrtf(variance / head_dim + epsilon); } __syncthreads(); // Apply normalization and weight float rms = s_rms; for (int d = threadIdx.x; d < head_dim; d += blockDim.x) { float x = static_cast(input[base_offset + d]); float w = static_cast(weight[weight_offset + d]); out[base_offset + d] = static_cast(x * rms * w); } } torch::Tensor head_rms_norm( torch::Tensor& input, // [num_tokens, num_heads * head_dim] torch::Tensor& weight, // [num_heads * head_dim] double epsilon, int64_t norm_head_dim) // head_dim { auto dtype = input.scalar_type(); TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16 || dtype == torch::kFloat32, "head_rms_norm supports fp16, bf16, and fp32 data types"); TORCH_CHECK(weight.scalar_type() == dtype, 
"head_rms_norm expects input and weight to have the same dtype"); TORCH_CHECK(input.dim() >= 2, "head_rms_norm expects input with at least 2 dimensions"); int hidden_size = input.size(-1); int num_heads = hidden_size / norm_head_dim; TORCH_CHECK(hidden_size % norm_head_dim == 0, "hidden_size (", hidden_size, ") must be divisible by norm_head_dim (", norm_head_dim, ")"); TORCH_CHECK(weight.size(-1) == hidden_size, "weight last dim (", weight.size(-1), ") must match input last dim (", hidden_size, ")"); int num_tokens = input.numel() / hidden_size; torch::Tensor out = torch::empty_like(input); const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); const hipStream_t stream = at::hip::getCurrentHIPStream(); const int threads = std::min(static_cast(norm_head_dim), 256); dim3 grid(num_tokens, num_heads); AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "head_rms_norm", ([&] { head_rms_norm_kernel<<>>( out.data_ptr(), input.data_ptr(), weight.data_ptr(), static_cast(epsilon), num_tokens, num_heads, static_cast(norm_head_dim)); })); return out; }