sglang · Commits · 0380ca82

Commit 0380ca82 (unverified), authored Oct 29, 2025 by Yuzhen Zhou, committed by GitHub on Oct 28, 2025

Add Batch-Invariant RMSNorm (#12144)

Parent: ec92b0ce
Showing 3 changed files with 138 additions and 2 deletions:

  +2    -0   python/sglang/srt/batch_invariant_ops/__init__.py
  +120  -0   python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
  +16   -2   python/sglang/srt/layers/layernorm.py
python/sglang/srt/batch_invariant_ops/__init__.py

@@ -9,6 +9,7 @@ from .batch_invariant_ops import (
    log_softmax,
    matmul_persistent,
    mean_dim,
    rms_norm_batch_invariant,
    set_batch_invariant_mode,
)

@@ -24,4 +25,5 @@ __all__ = [
    "mean_dim",
    "get_batch_invariant_attention_block_size",
    "AttentionBlockSize",
    "rms_norm_batch_invariant",
]
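
The package now re-exports rms_norm_batch_invariant, so callers can import it directly from sglang.srt.batch_invariant_ops. A minimal usage sketch, not part of this commit (shapes, dtype, and device are illustrative; it assumes a CUDA-capable install with Triton available):

import torch
from sglang.srt.batch_invariant_ops import rms_norm_batch_invariant

# Illustrative shapes; the only real requirement is that the weight length
# matches the last dimension of the input.
x = torch.randn(4, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.ones(4096, device="cuda", dtype=torch.bfloat16)
y = rms_norm_batch_invariant(x, w, eps=1e-6)  # same shape and dtype as x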
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py

@@ -579,6 +579,126 @@ def bmm_batch_invariant(a, b, *, out=None):
    )


@triton.jit
def _rms_norm_kernel(
    input_ptr,
    weight_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute RMS normalization along the last dimension of a 2D tensor.

    RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight

    Each block handles one row of the input tensor.
    """
    row_idx = tl.program_id(0).to(tl.int64)
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * output_row_stride

    # Step 1: Compute sum of squares in float32 to avoid overflow
    sum_sq = tl.zeros([1], dtype=tl.float32)
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        # Convert to float32 for accumulation to prevent overflow
        vals_f32 = vals.to(tl.float32)
        sq_vals = vals_f32 * vals_f32
        sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))

    # Step 2: Compute RMS (root mean square) in float32
    mean_sq = sum_sq / n_cols
    rms = tl.sqrt(mean_sq + eps)
    inv_rms = 1.0 / rms

    # Step 3: Normalize and apply weight
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
        # Compute in float32 then convert back to input dtype
        vals_f32 = vals.to(tl.float32)
        weight_f32 = weight.to(tl.float32)
        output_f32 = vals_f32 * inv_rms * weight_f32
        output = output_f32.to(vals.dtype)
        tl.store(output_row_start_ptr + col_idx, output, mask=mask)
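
Aside, not part of the diff: the kernel reduces each row in fixed BLOCK_SIZE chunks, so the accumulation order for a given row never depends on how many rows are launched. That matters because floating-point addition is not associative; a small sketch of the effect, using NumPy purely for illustration:

import numpy as np

x = np.random.RandomState(0).standard_normal(4096).astype(np.float32)

one_pass = np.sum(x, dtype=np.float32)           # one reduction order
chunked = np.float32(0.0)
for i in range(0, x.size, 1000):                 # a different chunking
    chunked += np.sum(x[i:i + 1000], dtype=np.float32)

# The two sums typically differ in the last bits, which is exactly the kind of
# batch-size-dependent drift a fixed reduction order avoids.
print(one_pass == chunked, abs(one_pass - chunked))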
def rms_norm(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """
    Compute RMS normalization using Triton kernel.

    RMS Norm normalizes the input by the root mean square and scales by weight:
    output = input / sqrt(mean(input^2) + eps) * weight

    Args:
        input: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        Tensor with RMS normalization applied along the last dimension
    """
    assert weight.dim() == 1, "Weight must be 1-dimensional"
    assert input.shape[-1] == weight.shape[0], (
        f"Input last dimension ({input.shape[-1]}) must match "
        f"weight dimension ({weight.shape[0]})"
    )

    # Flatten all dimensions except the last one
    original_shape = input.shape
    input_2d = input.reshape(-1, input.shape[-1])
    input_2d = input_2d.contiguous()
    weight = weight.contiguous()

    n_rows, n_cols = input_2d.shape
    output = torch.empty_like(input_2d)

    BLOCK_SIZE = 1024
    grid = (n_rows,)

    _rms_norm_kernel[grid](
        input_2d,
        weight,
        output,
        input_2d.stride(0),
        output.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return output.reshape(original_shape)
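
For reference (not part of the commit), the formula in the docstring above has a direct eager PyTorch counterpart. A hedged sketch of an equivalence check against the Triton path follows; tolerances are loose because the kernel accumulates in float32 and casts back to the input dtype:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Mirrors the kernel: compute in float32, then cast back to the input dtype.
    x_f32 = x.float()
    inv_rms = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_f32 * inv_rms * weight.float()).to(x.dtype)

# Hypothetical check (requires a CUDA device and Triton):
# x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
# w = torch.randn(4096, device="cuda", dtype=torch.float16)
# torch.testing.assert_close(rms_norm(x, w), rms_norm_reference(x, w), rtol=1e-2, atol=1e-2)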
def rms_norm_batch_invariant(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """
    Batch-invariant wrapper for RMS normalization.

    This function provides a deterministic, batch-invariant implementation
    of RMS normalization for use with the batch_invariant mode.

    Adapted from @https://github.com/vllm-project/vllm/blob/66a168a197ba214a5b70a74fa2e713c9eeb3251a/vllm/model_executor/layers/batch_invariant.py#L649

    Args:
        input: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        RMS normalized tensor
    """
    return rms_norm(input, weight, eps=eps)
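
A quick way to see what "batch-invariant" buys here, sketched below and not part of the commit: normalizing a row alone or inside a larger batch should produce bitwise-identical results, since each Triton program handles one row with a fixed reduction order (assumes a CUDA device):

import torch
from sglang.srt.batch_invariant_ops import rms_norm_batch_invariant

hidden = 4096
w = torch.randn(hidden, device="cuda", dtype=torch.bfloat16)
batch = torch.randn(8, hidden, device="cuda", dtype=torch.bfloat16)

full = rms_norm_batch_invariant(batch, w)        # rows normalized as a batch of 8
single = rms_norm_batch_invariant(batch[:1], w)  # first row normalized alone

# Bitwise equality, not just allclose, is the point of the batch-invariant path.
assert torch.equal(full[:1], single)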
_batch_invariant_MODE = False
_batch_invariant_LIB = None
_original_torch_bmm = None
python/sglang/srt/layers/layernorm.py

@@ -20,7 +20,12 @@ import torch
import torch.nn as nn
from packaging.version import Version

from sglang.srt.batch_invariant_ops import (
    is_batch_invariant_mode_enabled,
    rms_norm_batch_invariant,
)
from sglang.srt.custom_op import CustomOp
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
    cpu_has_amx_support,
    get_bool_env_var,

@@ -90,8 +95,6 @@ class RMSNorm(CustomOp):
        )

        if _use_aiter:
            self._forward_method = self.forward_aiter
        if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
            self._forward_method = self.forward_native

    def forward_cuda(
        self,

@@ -100,6 +103,17 @@ class RMSNorm(CustomOp):
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)
        if is_batch_invariant_mode_enabled():
            if (
                residual is not None
                or get_global_server_args().rl_on_policy_target == "fsdp"
            ):
                return self.forward_native(x, residual)
            return rms_norm_batch_invariant(
                x,
                self.weight.data,
                self.variance_epsilon,
            )
        if residual is not None:
            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
            return x, residual
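
To summarize the dispatch that the new forward_cuda branch introduces, here is an illustrative restatement (not part of the commit) of when the batch-invariant kernel is actually used versus when the code falls back to forward_native or the fused path:

def takes_batch_invariant_path(
    batch_invariant_mode: bool,
    variance_size_override_set: bool,
    has_residual: bool,
    rl_on_policy_target: str,
) -> bool:
    """True when RMSNorm.forward_cuda would call rms_norm_batch_invariant."""
    if variance_size_override_set:
        return False  # handled by forward_native
    if not batch_invariant_mode:
        return False  # normal CUDA paths (fused_add_rmsnorm when a residual is passed)
    if has_residual or rl_on_policy_target == "fsdp":
        return False  # batch-invariant mode still falls back to forward_native here
    return True       # rms_norm_batch_invariant(x, weight, variance_epsilon)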