Unverified Commit 0380ca82 authored by Yuzhen Zhou's avatar Yuzhen Zhou Committed by GitHub
Browse files

Add Batch‑Invariant RMSNorm (#12144)

parent ec92b0ce
......@@ -9,6 +9,7 @@ from .batch_invariant_ops import (
log_softmax,
matmul_persistent,
mean_dim,
rms_norm_batch_invariant,
set_batch_invariant_mode,
)
......@@ -24,4 +25,5 @@ __all__ = [
"mean_dim",
"get_batch_invariant_attention_block_size",
"AttentionBlockSize",
"rms_norm_batch_invariant",
]
......@@ -579,6 +579,126 @@ def bmm_batch_invariant(a, b, *, out=None):
)
@triton.jit
def _rms_norm_kernel(
input_ptr,
weight_ptr,
output_ptr,
input_row_stride,
output_row_stride,
n_cols,
eps,
BLOCK_SIZE: tl.constexpr,
):
"""
Compute RMS normalization along the last dimension of a 2D tensor.
RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
Each block handles one row of the input tensor.
"""
row_idx = tl.program_id(0).to(tl.int64)
row_start_ptr = input_ptr + row_idx * input_row_stride
output_row_start_ptr = output_ptr + row_idx * output_row_stride
# Step 1: Compute sum of squares in float32 to avoid overflow
sum_sq = tl.zeros([1], dtype=tl.float32)
for col_offset in range(0, n_cols, BLOCK_SIZE):
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
mask = col_idx < n_cols
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
# Convert to float32 for accumulation to prevent overflow
vals_f32 = vals.to(tl.float32)
sq_vals = vals_f32 * vals_f32
sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))
# Step 2: Compute RMS (root mean square) in float32
mean_sq = sum_sq / n_cols
rms = tl.sqrt(mean_sq + eps)
inv_rms = 1.0 / rms
# Step 3: Normalize and apply weight
for col_offset in range(0, n_cols, BLOCK_SIZE):
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
mask = col_idx < n_cols
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
# Compute in float32 then convert back to input dtype
vals_f32 = vals.to(tl.float32)
weight_f32 = weight.to(tl.float32)
output_f32 = vals_f32 * inv_rms * weight_f32
output = output_f32.to(vals.dtype)
tl.store(output_row_start_ptr + col_idx, output, mask=mask)
def rms_norm(
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
"""
Compute RMS normalization using Triton kernel.
RMS Norm normalizes the input by the root mean square and scales by weight:
output = input / sqrt(mean(input^2) + eps) * weight
Args:
input: Input tensor of shape (..., hidden_size)
weight: Weight tensor of shape (hidden_size,)
eps: Small constant for numerical stability
Returns:
Tensor with RMS normalization applied along the last dimension
"""
assert weight.dim() == 1, "Weight must be 1-dimensional"
assert input.shape[-1] == weight.shape[0], (
f"Input last dimension ({input.shape[-1]}) must match "
f"weight dimension ({weight.shape[0]})"
)
# Flatten all dimensions except the last one
original_shape = input.shape
input_2d = input.reshape(-1, input.shape[-1])
input_2d = input_2d.contiguous()
weight = weight.contiguous()
n_rows, n_cols = input_2d.shape
output = torch.empty_like(input_2d)
BLOCK_SIZE = 1024
grid = (n_rows,)
_rms_norm_kernel[grid](
input_2d,
weight,
output,
input_2d.stride(0),
output.stride(0),
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
)
return output.reshape(original_shape)
def rms_norm_batch_invariant(
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
"""
Batch-invariant wrapper for RMS normalization.
This function provides a deterministic, batch-invariant implementation
of RMS normalization for use with the batch_invariant mode.
Adapted from @https://github.com/vllm-project/vllm/blob/66a168a197ba214a5b70a74fa2e713c9eeb3251a/vllm/model_executor/layers/batch_invariant.py#L649
Args:
input: Input tensor of shape (..., hidden_size)
weight: Weight tensor of shape (hidden_size,)
eps: Small constant for numerical stability
Returns:
RMS normalized tensor
"""
return rms_norm(input, weight, eps=eps)
_batch_invariant_MODE = False
_batch_invariant_LIB = None
_original_torch_bmm = None
......
......@@ -20,7 +20,12 @@ import torch
import torch.nn as nn
from packaging.version import Version
from sglang.srt.batch_invariant_ops import (
is_batch_invariant_mode_enabled,
rms_norm_batch_invariant,
)
from sglang.srt.custom_op import CustomOp
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
......@@ -90,8 +95,6 @@ class RMSNorm(CustomOp):
)
if _use_aiter:
self._forward_method = self.forward_aiter
if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
self._forward_method = self.forward_native
def forward_cuda(
self,
......@@ -100,6 +103,17 @@ class RMSNorm(CustomOp):
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
if is_batch_invariant_mode_enabled():
if (
residual is not None
or get_global_server_args().rl_on_policy_target == "fsdp"
):
return self.forward_native(x, residual)
return rms_norm_batch_invariant(
x,
self.weight.data,
self.variance_epsilon,
)
if residual is not None:
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
return x, residual
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment