Commit 0380ca82 (Unverified)
Add Batch-Invariant RMSNorm (#12144)
Authored Oct 29, 2025 by Yuzhen Zhou; committed by GitHub on Oct 28, 2025
Parent: ec92b0ce
Showing 3 changed files with 138 additions and 2 deletions:

- python/sglang/srt/batch_invariant_ops/__init__.py (+2, -0)
- python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py (+120, -0)
- python/sglang/srt/layers/layernorm.py (+16, -2)
python/sglang/srt/batch_invariant_ops/__init__.py
```diff
@@ -9,6 +9,7 @@ from .batch_invariant_ops import (
     log_softmax,
     matmul_persistent,
     mean_dim,
+    rms_norm_batch_invariant,
     set_batch_invariant_mode,
 )
@@ -24,4 +25,5 @@ __all__ = [
     "mean_dim",
     "get_batch_invariant_attention_block_size",
     "AttentionBlockSize",
+    "rms_norm_batch_invariant",
 ]
```
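With the new export in place, the op can be imported straight from the package. A minimal usage sketch, assuming a CUDA device with Triton available (the op launches a Triton kernel); the shapes are illustrative and not taken from the commit:

```python
import torch

from sglang.srt.batch_invariant_ops import rms_norm_batch_invariant

# Illustrative shapes: 4 tokens with a hidden size of 8 (not from the commit).
x = torch.randn(4, 8, dtype=torch.bfloat16, device="cuda")
weight = torch.ones(8, dtype=torch.bfloat16, device="cuda")

# Computes x / sqrt(mean(x^2) + eps) * weight along the last dimension.
y = rms_norm_batch_invariant(x, weight, eps=1e-6)
print(y.shape)  # torch.Size([4, 8])
```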
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
@@ -579,6 +579,126 @@ def bmm_batch_invariant(a, b, *, out=None):

```python
    )


@triton.jit
def _rms_norm_kernel(
    input_ptr,
    weight_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute RMS normalization along the last dimension of a 2D tensor.

    RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight

    Each block handles one row of the input tensor.
    """
    row_idx = tl.program_id(0).to(tl.int64)
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * output_row_stride

    # Step 1: Compute sum of squares in float32 to avoid overflow
    sum_sq = tl.zeros([1], dtype=tl.float32)
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        # Convert to float32 for accumulation to prevent overflow
        vals_f32 = vals.to(tl.float32)
        sq_vals = vals_f32 * vals_f32
        sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))

    # Step 2: Compute RMS (root mean square) in float32
    mean_sq = sum_sq / n_cols
    rms = tl.sqrt(mean_sq + eps)
    inv_rms = 1.0 / rms

    # Step 3: Normalize and apply weight
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
        # Compute in float32 then convert back to input dtype
        vals_f32 = vals.to(tl.float32)
        weight_f32 = weight.to(tl.float32)
        output_f32 = vals_f32 * inv_rms * weight_f32
        output = output_f32.to(vals.dtype)
        tl.store(output_row_start_ptr + col_idx, output, mask=mask)


def rms_norm(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """
    Compute RMS normalization using Triton kernel.

    RMS Norm normalizes the input by the root mean square and scales by weight:
        output = input / sqrt(mean(input^2) + eps) * weight

    Args:
        input: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        Tensor with RMS normalization applied along the last dimension
    """
    assert weight.dim() == 1, "Weight must be 1-dimensional"
    assert input.shape[-1] == weight.shape[0], (
        f"Input last dimension ({input.shape[-1]}) must match "
        f"weight dimension ({weight.shape[0]})"
    )

    # Flatten all dimensions except the last one
    original_shape = input.shape
    input_2d = input.reshape(-1, input.shape[-1])
    input_2d = input_2d.contiguous()
    weight = weight.contiguous()

    n_rows, n_cols = input_2d.shape
    output = torch.empty_like(input_2d)

    BLOCK_SIZE = 1024
    grid = (n_rows,)
    _rms_norm_kernel[grid](
        input_2d,
        weight,
        output,
        input_2d.stride(0),
        output.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return output.reshape(original_shape)


def rms_norm_batch_invariant(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """
    Batch-invariant wrapper for RMS normalization.

    This function provides a deterministic, batch-invariant implementation
    of RMS normalization for use with the batch_invariant mode.

    Adapted from @https://github.com/vllm-project/vllm/blob/66a168a197ba214a5b70a74fa2e713c9eeb3251a/vllm/model_executor/layers/batch_invariant.py#L649

    Args:
        input: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        RMS normalized tensor
    """
    return rms_norm(input, weight, eps=eps)


_batch_invariant_MODE = False
_batch_invariant_LIB = None
_original_torch_bmm = None
```
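Since the kernel accumulates in float32 with a fixed per-row loop, its output can be sanity-checked against a plain eager-mode RMSNorm that follows the formula in the docstring. A sketch of such a check, assuming a CUDA device with Triton available; `rms_norm_reference` is an illustrative helper, not part of the commit:

```python
import torch

from sglang.srt.batch_invariant_ops import rms_norm_batch_invariant


def rms_norm_reference(
    x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # Same formula as the kernel docstring: y = x / sqrt(mean(x^2) + eps) * weight,
    # computed in float32 and cast back to the input dtype.
    x_f32 = x.float()
    inv_rms = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_f32 * inv_rms * w.float()).to(x.dtype)


x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
w = torch.randn(4096, dtype=torch.bfloat16, device="cuda")

out_kernel = rms_norm_batch_invariant(x, w)
out_ref = rms_norm_reference(x, w)

# Expect close agreement; tiny differences can remain because the two
# implementations reduce the squared values in different orders.
print((out_kernel.float() - out_ref.float()).abs().max())
```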
python/sglang/srt/layers/layernorm.py
```diff
@@ -20,7 +20,12 @@ import torch
 import torch.nn as nn
 from packaging.version import Version
 
+from sglang.srt.batch_invariant_ops import (
+    is_batch_invariant_mode_enabled,
+    rms_norm_batch_invariant,
+)
 from sglang.srt.custom_op import CustomOp
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -90,8 +95,6 @@ class RMSNorm(CustomOp):
             )
         if _use_aiter:
             self._forward_method = self.forward_aiter
-        if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
-            self._forward_method = self.forward_native
 
     def forward_cuda(
         self,
@@ -100,6 +103,17 @@ class RMSNorm(CustomOp):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if self.variance_size_override is not None:
             return self.forward_native(x, residual)
+        if is_batch_invariant_mode_enabled():
+            if (
+                residual is not None
+                or get_global_server_args().rl_on_policy_target == "fsdp"
+            ):
+                return self.forward_native(x, residual)
+            return rms_norm_batch_invariant(
+                x,
+                self.weight.data,
+                self.variance_epsilon,
+            )
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
             return x, residual
```
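The reason forward_cuda routes through rms_norm_batch_invariant when batch-invariant mode is enabled is that each row is reduced with a fixed, batch-size-independent schedule, so a request's normalization does not depend on what it happens to be batched with. A sketch of checking that property directly on the op, assuming a CUDA device with Triton available:

```python
import torch

from sglang.srt.batch_invariant_ops import rms_norm_batch_invariant

hidden = 4096
w = torch.randn(hidden, dtype=torch.bfloat16, device="cuda")
batch = torch.randn(32, hidden, dtype=torch.bfloat16, device="cuda")

# Normalize the full batch, then normalize the first row on its own.
full = rms_norm_batch_invariant(batch, w)
single = rms_norm_batch_invariant(batch[0:1], w)

# Batch invariance: the row's result should be bitwise identical either way.
print(torch.equal(full[0:1], single))  # expected: True
```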