change / sglang · Commits · Commit 9a0d0b75 (unverified)

Authored Aug 31, 2025 by Vincent Zhong; committed by GitHub, Aug 31, 2025
Parent: ba861293

[Performance] Improve Qwen RMSNorm by replacing with native RMSNorm op (#9709)

Showing 1 changed file with 44 additions and 18 deletions.

python/sglang/srt/models/qwen2_5_vl.py (+44, -18)
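The change replaces the transformers Qwen2RMSNorm module with sglang's own RMSNorm layer (kernel-backed on CUDA) in the vision blocks and the patch merger, flattening activations to 2D [num_tokens, hidden] where the native op is applied. Both normalizations compute the same quantity; below is a reference sketch plus a parity check against the removed Qwen2RMSNorm, assuming a transformers version that still exposes it (everything other than the Qwen2RMSNorm import is illustrative).

import torch
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm  # the dependency this commit drops


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm: scale each hidden vector by the reciprocal of its root mean square,
    # computed in float32, then apply the learned per-channel weight.
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return weight * (x.float() * torch.rsqrt(variance + eps)).to(x.dtype)


dim, eps = 64, 1e-6
hf_norm = Qwen2RMSNorm(dim, eps=eps)
x = torch.randn(8, dim)

torch.testing.assert_close(hf_norm(x), rms_norm_reference(x, hf_norm.weight, eps))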
@@ -31,7 +31,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 from transformers.activations import ACT2FN
-from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
@@ -43,6 +42,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
@@ -122,8 +122,8 @@ class Qwen2_5_VisionBlock(nn.Module):
         super().__init__()
         if norm_layer is None:
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
-        self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
-        self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
+        self.norm1 = RMSNorm(dim, eps=1e-6)
+        self.norm2 = RMSNorm(dim, eps=1e-6)
         if attn_implementation is None:
             softmax_in_single_precision = False
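One property the rest of the diff leans on: RMSNorm acts only along the last (hidden) dimension, so flattening [S, B, H] to [S*B, H], normalizing, and reshaping back gives exactly the same result as normalizing the 3D tensor, which is what the rewritten forward() below does before calling the 2D-oriented op. A minimal sketch in plain PyTorch (unweighted, shapes chosen arbitrarily):

import torch


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize over the last dimension only.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


S, B, H = 5, 3, 16
x = torch.randn(S, B, H)

out_3d = rms_norm(x)                                  # normalize [S, B, H] directly
out_2d = rms_norm(x.reshape(-1, H)).reshape(S, B, H)  # flatten, normalize, reshape back

torch.testing.assert_close(out_3d, out_2d)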
@@ -174,18 +174,29 @@ class Qwen2_5_VisionBlock(nn.Module):
         cu_seqlens: torch.Tensor,
         position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
-        hidden_states = self.norm1(x)
-        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
+        S, B, H = x.shape
+        # norm1: flatten to 2D -> [S*B, H], then reshape back
+        x2d = x.reshape(-1, H)
+        hidden_states = self.norm1(x2d).reshape(S, B, H)
+        # Attention expects [B, S, H]
+        hidden_states = rearrange(hidden_states, "s b h -> b s h")
         attn = self.attn(
             hidden_states,
             cu_seqlens=cu_seqlens,
             position_embeddings=position_embeddings,
         )
-        attn = rearrange(attn, "b s ... -> s b ...")
-        x = x + attn
-        norm2 = self.norm2(x)
-        mlp = self.mlp(norm2)
-        x = x + mlp
+        attn = rearrange(attn, "b s h -> s b h")
+        # norm2 with fused residual-add: also 2D
+        attn2d = attn.reshape(-1, H)
+        x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
+        x_norm = x_norm_2d.reshape(S, B, H)
+        x_after_add = x_after_add_2d.reshape(S, B, H)
+        # MLP and final residual
+        mlp_out = self.mlp(x_norm)
+        x = x_after_add + mlp_out
         return x
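The new norm2 call folds the attention residual add into the norm: it returns the normalized tensor together with the post-add residual, and the latter is reused for the final residual connection. A data-flow sketch comparing the old and new paths, using a pure-PyTorch stand-in for the fused op (the (normed, added) return convention of RMSNorm(x, residual=...) is an assumption here, as is the toy mlp):

import torch


def rms_norm(x, weight, eps=1e-6):
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return weight * (x.float() * torch.rsqrt(variance + eps)).to(x.dtype)


def fused_add_rms_norm(x, residual, weight, eps=1e-6):
    # Stand-in for `self.norm2(x2d, residual=attn2d)`: assumed to return
    # (rms_norm(x + residual), x + residual), i.e. the pair the new forward()
    # unpacks as (x_norm_2d, x_after_add_2d).
    added = x + residual
    return rms_norm(added, weight, eps), added


S, B, H = 3, 2, 8
x = torch.randn(S, B, H)      # block input
attn = torch.randn(S, B, H)   # attention output, already rearranged back to [S, B, H]
w = torch.randn(H)
mlp = torch.nn.Linear(H, H)

# Old path: residual add, norm, MLP, residual add, all on 3D tensors.
x_res = x + attn
out_old = x_res + mlp(rms_norm(x_res, w))

# New path: flatten to 2D, fused add+norm, reshape back, MLP, residual add.
x_norm_2d, x_after_add_2d = fused_add_rms_norm(x.reshape(-1, H), attn.reshape(-1, H), w)
out_new = x_after_add_2d.reshape(S, B, H) + mlp(x_norm_2d.reshape(S, B, H))

torch.testing.assert_close(out_old, out_new)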
@@ -201,7 +212,7 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
-        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
+        self.ln_q = RMSNorm(context_dim, eps=1e-6)
         self.mlp = nn.ModuleList(
             [
                 ColumnParallelLinear(
@@ -223,11 +234,13 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.ln_q(x)
-        x = x.view(-1, self.hidden_size)
+        # x expected shape: [S, B, context_dim]
+        S, B, D = x.shape
+        x2d = x.reshape(-1, D)
+        x2d = self.ln_q(x2d)  # RMSNorm expects 2D
+        x2d = x2d.view(-1, self.hidden_size)  # group into spatial_merge_unit
         mlp_fc1, mlp_act, mlp_fc2 = self.mlp
-        x_parallel, _ = mlp_fc1(x)
+        x_parallel, _ = mlp_fc1(x2d)
         x_parallel = mlp_act(x_parallel)
         out, _ = mlp_fc2(x_parallel)
         return out
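In the patch merger, ln_q now normalizes per patch token ([*, context_dim] rows), and only afterwards does the view group spatial_merge_size**2 neighbouring tokens into one merged hidden vector for the MLP. A shape walkthrough with made-up sizes (purely illustrative):

import torch

# Illustrative sizes only.
S, B = 16, 1                 # patch tokens, batch
context_dim = 1280
spatial_merge_size = 2
hidden_size = context_dim * spatial_merge_size**2    # 5120

x = torch.randn(S, B, context_dim)
x2d = x.reshape(-1, context_dim)       # [16, 1280] -> rows that ln_q normalizes
merged = x2d.view(-1, hidden_size)     # [4, 5120]  -> 2x2 groups of patches per row
print(x2d.shape, merged.shape)         # torch.Size([16, 1280]) torch.Size([4, 5120])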
@@ -394,6 +407,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         )
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+        # Move window_index to the same device as x before using it to index x
+        window_index = window_index.to(device=x.device)
+        # Ensure rotary_pos_emb is on the same device/dtype as x
+        rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=x.dtype)
         seq_len, _ = x.size()
         x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
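The two added .to(...) calls pin window_index and rotary_pos_emb to the device (and dtype) of x before they are used to index and embed it; when source and target already match, PyTorch returns the tensor itself, so the calls cost nothing in the common case. A tiny illustration:

import torch

x = torch.randn(4, 8)
idx = torch.tensor([2, 0, 3, 1])

# .to() is a no-op (returns the same tensor object) when device and dtype already match.
assert idx.to(device=x.device) is idx

reordered = x[idx.to(device=x.device)]   # safe even if idx had started on another device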
@@ -406,12 +425,19 @@ class Qwen2_5_VisionTransformer(nn.Module):
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
         position_embeddings = (emb.cos(), emb.sin())
+        # After building position_embeddings, make sure both cos and sin are on the same device/dtype as the attention input
+        position_embeddings = (
+            position_embeddings[0].to(x.device, x.dtype),
+            position_embeddings[1].to(x.device, x.dtype),
+        )

-        # compute cu_seqlens
+        # compute cu_seqlens - move cu_seqlens to GPU and make it int32
         cu_seqlens = torch.cat(
             [
-                torch.tensor([0], device=grid_thw.device),
-                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
+                torch.tensor([0], device=x.device, dtype=torch.int32),
+                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2])
+                .cumsum(dim=0)
+                .to(device=x.device, dtype=torch.int32),
             ]
         )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
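For reference, a self-contained sketch of what the new cu_seqlens construction yields, with a made-up grid_thw of two images (1x4x4 and 1x2x2 patch grids); the device/dtype handling mirrors the lines above:

import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 4, 4], [1, 2, 2]])   # (t, h, w) per image, illustrative
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cu_seqlens = torch.cat(
    [
        torch.tensor([0], device=device, dtype=torch.int32),
        (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2])
        .cumsum(dim=0)
        .to(device=device, dtype=torch.int32),
    ]
)                                                        # -> [0, 16, 20], int32, on `device`
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)    # pads one more 0 on the left, as above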