sglang · Commits · 86b04d25

Unverified commit 86b04d25, authored Oct 17, 2025 by Mick, committed by GitHub Oct 16, 2025

model: qwen3-omni (thinker-only) (#10911)

Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Parent: 85ebeecf

Showing 16 changed files with 1947 additions and 328 deletions (+1947 -328)
Files changed (16):

  python/sglang/srt/configs/model_config.py                    +1    -0
  python/sglang/srt/configs/qwen3_omni.py                      +613  -0
  python/sglang/srt/configs/qwen3_vl.py                        +0    -10
  python/sglang/srt/layers/rotary_embedding.py                 +357  -2
  python/sglang/srt/managers/mm_utils.py                       +18   -16
  python/sglang/srt/managers/tokenizer_manager.py              +2    -2
  python/sglang/srt/models/qwen2_moe.py                        +1    -0
  python/sglang/srt/models/qwen3_moe.py                        +2    -1
  python/sglang/srt/models/qwen3_omni_moe.py                   +661  -0
  python/sglang/srt/models/qwen3_vl.py                         +38   -24
  python/sglang/srt/models/qwen3_vl_moe.py                     +53   -168
  python/sglang/srt/multimodal/processors/base_processor.py    +2    -1
  python/sglang/srt/multimodal/processors/qwen_vl.py           +40   -6
  test/srt/test_vision_openai_server_a.py                      +2    -1
  test/srt/test_vision_openai_server_b.py                      +25   -1
  test/srt/test_vision_openai_server_common.py                 +132  -96
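Aside (not part of the commit): with this change in place, a Qwen3-Omni checkpoint should load through SGLang's normal multimodal path, with only the thinker (text-generating) branch served. A minimal sketch using the offline engine API; the model path below is an assumption, and audio/image inputs would go through the multimodal processors touched later in this diff.

    import sglang as sgl

    if __name__ == "__main__":
        # Hypothetical checkpoint name; adjust to the actual Qwen3-Omni release.
        llm = sgl.Engine(model_path="Qwen/Qwen3-Omni-30B-A3B-Instruct")
        print(llm.generate("Hello", {"max_new_tokens": 16}))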
python/sglang/srt/configs/model_config.py  (view file @ 86b04d25)

@@ -853,6 +853,7 @@ multimodal_model_archs = [
     "Qwen2_5_VLForConditionalGeneration",
     "Qwen3VLForConditionalGeneration",
     "Qwen3VLMoeForConditionalGeneration",
+    "Qwen3OmniMoeForConditionalGeneration",
     "KimiVLForConditionalGeneration",
     "InternVLChatModel",
     "InternS1ForConditionalGeneration",
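Aside (not part of the diff): registering the architecture string here is what routes a Qwen3-Omni checkpoint through the multimodal code path. A tiny sanity check, assuming the list is importable at this path:

    from sglang.srt.configs.model_config import multimodal_model_archs

    # The architecture name comes from the checkpoint's config.json
    # ("architectures": ["Qwen3OmniMoeForConditionalGeneration"]).
    print("Qwen3OmniMoeForConditionalGeneration" in multimodal_model_archs)  # True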
python/sglang/srt/configs/qwen3_omni.py  (new file, 0 → 100644, view file @ 86b04d25)

from transformers import PretrainedConfig
from transformers.configuration_utils import layer_type_validation
from transformers.modeling_rope_utils import rope_config_validation

from sglang.utils import logger


class Qwen3OmniMoeAudioEncoderConfig(PretrainedConfig):
    model_type = "qwen3_omni_moe_audio_encoder"

    def __init__(
        self,
        num_mel_bins=128,
        encoder_layers=32,
        encoder_attention_heads=20,
        encoder_ffn_dim=5120,
        d_model=1280,
        dropout=0,
        attention_dropout=0,
        activation_function="gelu",
        activation_dropout=0,
        scale_embedding=False,
        initializer_range=0.02,
        max_source_positions=1500,
        n_window=100,
        output_dim=3584,
        n_window_infer=400,
        conv_chunksize=500,
        downsample_hidden_size=480,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_mel_bins = num_mel_bins
        self.d_model = d_model
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_ffn_dim = encoder_ffn_dim
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_function = activation_function
        self.activation_dropout = activation_dropout
        self.num_hidden_layers = encoder_layers
        self.initializer_range = initializer_range
        self.scale_embedding = (
            scale_embedding  # scale factor will be sqrt(d_model) if True
        )
        self.max_source_positions = max_source_positions
        self.n_window = n_window
        self.output_dim = output_dim
        self.n_window_infer = n_window_infer
        self.conv_chunksize = conv_chunksize
        self.downsample_hidden_size = downsample_hidden_size


class Qwen3OmniMoeVisionEncoderConfig(PretrainedConfig):
    model_type = "qwen3_omni_moe_vision_encoder"
    base_config_key = "vision_config"

    def __init__(
        self,
        depth=27,
        hidden_size=1152,
        hidden_act="gelu_pytorch_tanh",
        intermediate_size=4304,
        num_heads=16,
        in_channels=3,
        patch_size=16,
        spatial_merge_size=2,
        temporal_patch_size=2,
        out_hidden_size=3584,
        num_position_embeddings=2304,
        deepstack_visual_indexes=[8, 16, 24],
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.depth = depth
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.out_hidden_size = out_hidden_size
        self.num_position_embeddings = num_position_embeddings
        self.initializer_range = initializer_range
        self.deepstack_visual_indexes = deepstack_visual_indexes


class Qwen3OmniMoeTextConfig(PretrainedConfig):
    model_type = "qwen3_omni_moe_text"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `Qwen3OmniMoeText`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.experts.*.gate_proj": "colwise",
        "layers.*.mlp.experts.*.up_proj": "colwise",
        "layers.*.mlp.experts.*.down_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=3584,
        hidden_size=2048,
        intermediate_size=18944,
        num_hidden_layers=28,
        num_attention_heads=28,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=1000000.0,
        rope_scaling=None,
        attention_bias=False,
        sliding_window=None,
        attention_dropout=0,
        decoder_sparse_step=1,
        moe_intermediate_size=768,
        num_experts_per_tok=8,
        num_experts=128,
        norm_topk_prob=True,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        mlp_only_layers=None,
        **kwargs,
    ):
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        # MoE arguments
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers


class Qwen3OmniMoeThinkerConfig(PretrainedConfig):
    model_type = "qwen3_omni_moe_thinker"
    attribute_map = {
        "image_token_id": "image_token_index",
        "video_token_id": "video_token_index",
        "audio_token_id": "audio_token_index",
    }
    sub_configs = {
        "audio_config": Qwen3OmniMoeAudioEncoderConfig,
        "vision_config": Qwen3OmniMoeVisionEncoderConfig,
        "text_config": Qwen3OmniMoeTextConfig,
    }

    def __init__(
        self,
        audio_config=None,
        vision_config=None,
        text_config=None,
        audio_token_id=151646,
        image_token_id=151655,
        video_token_id=151656,
        position_id_per_seconds=25,
        audio_start_token_id=151647,
        user_token_id=872,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.user_token_id = user_token_id
        self.position_id_per_seconds = position_id_per_seconds
        self.audio_start_token_id = audio_start_token_id
        self.initializer_range = initializer_range

        if isinstance(vision_config, dict):
            vision_config = Qwen3OmniMoeVisionEncoderConfig(**vision_config)
        elif vision_config is None:
            vision_config = Qwen3OmniMoeVisionEncoderConfig()
        self.vision_config = vision_config

        if isinstance(audio_config, dict):
            audio_config = Qwen3OmniMoeAudioEncoderConfig(**audio_config)
        elif audio_config is None:
            audio_config = Qwen3OmniMoeAudioEncoderConfig()
        self.audio_config = audio_config

        if isinstance(text_config, dict):
            text_config = Qwen3OmniMoeTextConfig(**text_config)
        elif text_config is None:
            text_config = Qwen3OmniMoeTextConfig()
        self.text_config = text_config

        self.audio_token_id = audio_token_id
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id


class Qwen3OmniMoeTalkerCodePredictorConfig(PretrainedConfig):
    model_type = "qwen3_omni_moe_talker_code_predictor"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `Qwen3OmniMoeTalkerCodePredictor`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=2048,
        hidden_size=1024,
        intermediate_size=3072,
        num_hidden_layers=5,
        num_attention_heads=16,
        num_key_value_heads=8,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=0.000001,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000,
        rope_scaling=None,
        attention_bias=False,
        sliding_window=None,
        layer_types=None,
        attention_dropout=0,
        num_code_groups=32,
        **kwargs,
    ):
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        self.layer_types = layer_types
        if self.layer_types is None:
            self.layer_types = [
                (
                    "sliding_attention"
                    if self.sliding_window is not None and i >= self.max_window_layers
                    else "full_attention"
                )
                for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types, self.num_hidden_layers)
        self.num_code_groups = num_code_groups


class Qwen3OmniMoeTalkerTextConfig(PretrainedConfig):
    model_type = "qwen3_omni_moe_talker_text"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `Qwen3OmniMoeTalkerText`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.experts.*.gate_proj": "colwise",
        "layers.*.mlp.experts.*.up_proj": "colwise",
        "layers.*.mlp.experts.*.down_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=3072,
        hidden_size=1024,
        intermediate_size=2048,
        num_hidden_layers=20,
        num_attention_heads=16,
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=0.000001,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000,
        rope_scaling=None,
        attention_bias=False,
        sliding_window=None,
        attention_dropout=0,
        decoder_sparse_step=1,
        moe_intermediate_size=384,
        num_experts_per_tok=8,
        num_experts=128,
        norm_topk_prob=False,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        mlp_only_layers=None,
        **kwargs,
    ):
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        # MoE arguments
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers


class Qwen3OmniMoeTalkerConfig(PretrainedConfig):
    sub_configs = {
        "code_predictor_config": Qwen3OmniMoeTalkerCodePredictorConfig,
        "text_config": Qwen3OmniMoeTalkerTextConfig,
    }

    def __init__(
        self,
        code_predictor_config=None,
        text_config=None,
        num_code_groups=32,
        thinker_hidden_size=2048,
        codec_eos_token_id=4198,
        accept_hidden_layer=18,
        codec_nothink_id=4203,
        codec_think_bos_id=4204,
        codec_think_eos_id=4205,
        codec_pad_id=4196,
        codec_bos_id=4197,
        audio_token_id=151646,
        image_token_id=151655,
        video_token_id=151656,
        vision_start_token_id=151652,
        position_id_per_seconds=25,
        audio_start_token_id=151669,
        speaker_id=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if code_predictor_config is None:
            code_predictor_config = {}
            self.code_predictor_config = Qwen3OmniMoeTalkerCodePredictorConfig()
            logger.info(
                "code_predictor_config is None. Initializing code_predictor_config model with default values"
            )
        elif isinstance(code_predictor_config, Qwen3OmniMoeTalkerCodePredictorConfig):
            self.code_predictor_config = code_predictor_config
        else:
            self.code_predictor_config = Qwen3OmniMoeTalkerCodePredictorConfig(
                **code_predictor_config
            )
        if text_config is None:
            text_config = {}
            self.text_config = Qwen3OmniMoeTalkerTextConfig()
            logger.info(
                "talker text_config is None. Initializing talker text model with default values"
            )
        elif isinstance(text_config, Qwen3OmniMoeTalkerTextConfig):
            self.text_config = text_config
        else:
            self.text_config = Qwen3OmniMoeTalkerTextConfig(**text_config)
        self.num_code_groups = num_code_groups
        self.thinker_hidden_size = thinker_hidden_size
        self.codec_eos_token_id = codec_eos_token_id
        self.accept_hidden_layer = accept_hidden_layer
        self.codec_nothink_id = codec_nothink_id
        self.codec_think_bos_id = codec_think_bos_id
        self.codec_think_eos_id = codec_think_eos_id
        self.codec_pad_id = codec_pad_id
        self.codec_bos_id = codec_bos_id
        self.audio_token_id = audio_token_id
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.position_id_per_seconds = position_id_per_seconds
        self.audio_start_token_id = audio_start_token_id
        self.vision_start_token_id = vision_start_token_id
        self.speaker_id = speaker_id


class Qwen3OmniMoeCode2WavConfig(PretrainedConfig):
    def __init__(
        self,
        codebook_size=2048,
        hidden_size=1024,
        max_position_embeddings=8000,
        rope_theta=10000,
        num_attention_heads=16,
        num_key_value_heads=16,
        attention_bias=False,
        sliding_window=72,
        intermediate_size=3072,
        hidden_act="silu",
        layer_scale_initial_scale=0.01,
        rms_norm_eps=1e-5,
        num_hidden_layers=8,
        num_quantizers=16,
        upsample_rates=(8, 5, 4, 3),
        upsampling_ratios=(2, 2),
        decoder_dim=1536,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.codebook_size = codebook_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.attention_bias = attention_bias
        self.sliding_window = sliding_window
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.layer_scale_initial_scale = layer_scale_initial_scale
        self.rms_norm_eps = rms_norm_eps
        self.num_hidden_layers = num_hidden_layers
        self.num_quantizers = num_quantizers
        self.upsample_rates = upsample_rates
        self.upsampling_ratios = upsampling_ratios
        self.decoder_dim = decoder_dim
        self.attention_dropout = attention_dropout

    @property
    def layer_types(self):
        """
        All layer in code2wav should be sliding attention
        """
        return ["sliding_attention"] * self.num_hidden_layers


class Qwen3OmniMoeConfig(PretrainedConfig):
    model_type = "qwen3_omni_moe"
    sub_configs = {
        "thinker_config": Qwen3OmniMoeThinkerConfig,
        "talker_config": Qwen3OmniMoeTalkerConfig,
        "code2wav_config": Qwen3OmniMoeCode2WavConfig,
    }

    def __init__(
        self,
        thinker_config=None,
        talker_config=None,
        code2wav_config=None,
        enable_audio_output=True,
        im_start_token_id=151644,
        im_end_token_id=151645,
        tts_pad_token_id=151671,
        tts_bos_token_id=151672,
        tts_eos_token_id=151673,
        system_token_id=8948,
        user_token_id=872,
        assistant_token_id=77091,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if thinker_config is None:
            thinker_config = {}
            logger.info(
                "thinker_config is None. Initializing thinker model with default values"
            )
        if talker_config is None:
            talker_config = {}
            logger.info(
                "talker_config is None. Initializing talker model with default values"
            )
        if code2wav_config is None:
            code2wav_config = {}
            logger.info(
                "code2wav_config is None. Initializing code2wav model with default values"
            )
        self.thinker_config = Qwen3OmniMoeThinkerConfig(**thinker_config)
        self.talker_config = Qwen3OmniMoeTalkerConfig(**talker_config)
        self.code2wav_config = Qwen3OmniMoeCode2WavConfig(**code2wav_config)
        self.enable_audio_output = enable_audio_output
        self.im_start_token_id = im_start_token_id
        self.im_end_token_id = im_end_token_id
        self.tts_pad_token_id = tts_pad_token_id
        self.tts_bos_token_id = tts_bos_token_id
        self.tts_eos_token_id = tts_eos_token_id
        self.system_token_id = system_token_id
        self.user_token_id = user_token_id
        self.assistant_token_id = assistant_token_id

    def get_text_config(self, decoder=False) -> "PretrainedConfig":
        """
        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
        itself. On specific composite models, it is under a set of valid names.

        Args:
            decoder (`Optional[bool]`, *optional*, defaults to `False`):
                If set to `True`, then only search for decoder config names.
        """
        # Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
        # except for Qwen yet. This has to be generalized if more deeply nested configs are
        # added. NOTE: currently method used only by vLLM
        return self.thinker_config.get_text_config()
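Aside (not part of the file above): a small sketch of how the nested configs compose. Dict sub-configs are promoted to their config classes, and get_text_config() is expected to drill into the thinker's text config. The override values below are illustrative only.

    from sglang.srt.configs.qwen3_omni import Qwen3OmniMoeConfig

    cfg = Qwen3OmniMoeConfig(
        thinker_config={"text_config": {"num_hidden_layers": 2, "num_experts": 4}}
    )
    assert cfg.thinker_config.text_config.num_hidden_layers == 2
    # Used by vLLM/SGLang-style callers to reach the underlying LLM config.
    print(type(cfg.get_text_config()).__name__)  # expected: Qwen3OmniMoeTextConfig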
python/sglang/srt/configs/qwen3_vl.py  (view file @ 86b04d25)

from typing import Optional, Union

from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation

@@ -576,11 +574,3 @@ class Qwen3VLMoeConfig(PretrainedConfig):
        self.vision_start_token_id = vision_start_token_id
        self.vision_end_token_id = vision_end_token_id
        super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)


__all__ = [
    "Qwen3VLMoeConfig",
    "Qwen3VLMoeVisionConfig",
    "Qwen3VLConfig",
    "Qwen3VLVisionConfig",
]
python/sglang/srt/layers/rotary_embedding.py  (view file @ 86b04d25)

@@ -1156,6 +1156,20 @@ class MRotaryEmbedding(RotaryEmbedding):
        second_per_grid_ts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if model_type == "qwen3_omni_moe":
            # For qwen3-omni
            return MRotaryEmbedding.get_rope_index_qwen3_omni(
                spatial_merge_size,
                image_token_id,
                video_token_id,
                vision_start_token_id,
                tokens_per_second,
                input_ids,
                image_grid_thw,
                video_grid_thw,
                second_per_grid_ts,
                **kwargs,
            )
        if (
            model_type.startswith("qwen3_vl") or model_type.startswith("qwen3_vl_moe")
        ) and video_grid_thw is not None:

@@ -1163,6 +1177,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                video_grid_thw, video_grid_thw[:, 0], dim=0
            )
            video_grid_thw[:, 0] = 1

        mrope_position_deltas = []
        if input_ids is not None and (
            image_grid_thw is not None or video_grid_thw is not None

@@ -1248,7 +1263,11 @@ class MRotaryEmbedding(RotaryEmbedding):
                        time_tensor_long = time_tensor.long()
                        t_index = time_tensor_long.flatten()
-                   elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
+                   elif model_type in (
+                       "qwen2_vl",
+                       "qwen3_vl",
+                       "qwen3_vl_moe",
+                   ):
                        t_index = (
                            torch.arange(llm_grid_t)
                            .view(-1, 1)

@@ -1256,7 +1275,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                            .flatten()
                        )
                    else:
-                       raise RuntimeError("Unimplemented")
+                       raise RuntimeError(f"Unimplemented model type: {model_type}")
                    h_index = (
                        torch.arange(llm_grid_h)
                        .view(1, -1, 1)

@@ -1306,6 +1325,304 @@ class MRotaryEmbedding(RotaryEmbedding):
            mrope_position_deltas = max_position_ids + 1 - s
            return position_ids, mrope_position_deltas

    @staticmethod
    def get_rope_index_qwen3_omni(
        spatial_merge_size: int,
        image_token_id: int,
        video_token_id: int,
        vision_start_token_id: int,
        tokens_per_second: Optional[int] = None,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # For qwen3-omni
        audio_token_id = kwargs["audio_token_id"]
        audio_start_token_id = kwargs["audio_start_token_id"]
        position_id_per_seconds = kwargs["position_id_per_seconds"]
        use_audio_in_video = kwargs.get("use_audio_in_video", False)
        audio_seqlens = kwargs.get("audio_seqlens", None)
        second_per_grids = second_per_grid_ts

        mrope_position_deltas = []
        if input_ids is not None and (
            image_grid_thw is not None or video_grid_thw is not None
        ):
            total_input_ids = input_ids
            position_ids = torch.zeros(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=torch.float,
                device=input_ids.device,
            )
            image_idx, video_idx, audio_idx = 0, 0, 0
            for i, current_input_ids in enumerate(total_input_ids):
                image_nums, video_nums, audio_nums = 0, 0, 0
                vision_start_indices = torch.argwhere(
                    current_input_ids == vision_start_token_id
                ).squeeze(1)
                if vision_start_indices.numel() > 0:
                    vision_tokens = current_input_ids[vision_start_indices + 1]
                    image_nums = (vision_tokens == image_token_id).sum()
                    video_nums = (
                        (vision_tokens == audio_start_token_id).sum()
                        if use_audio_in_video
                        else (vision_tokens == video_token_id).sum()
                    )
                audio_nums = torch.sum(current_input_ids == audio_start_token_id)
                input_tokens = current_input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos, remain_audios = (
                    image_nums,
                    video_nums,
                    audio_nums,
                )
                multimodal_nums = (
                    image_nums + audio_nums
                    if use_audio_in_video
                    else image_nums + video_nums + audio_nums
                )
                for _ in range(multimodal_nums):
                    st_idx = (
                        llm_pos_ids_list[-1].max() + 1
                        if len(llm_pos_ids_list) > 0
                        else 0
                    )
                    ed_vision_start = (
                        input_tokens.index(vision_start_token_id, st)
                        if (
                            (
                                image_token_id in input_tokens
                                or video_token_id in input_tokens
                            )
                            and (remain_videos > 0 or remain_images > 0)
                        )
                        else len(input_tokens) + 1
                    )
                    ed_audio_start = (
                        input_tokens.index(audio_start_token_id, st)
                        if (audio_token_id in input_tokens and remain_audios > 0)
                        else len(input_tokens) + 1
                    )
                    min_ed = min(ed_vision_start, ed_audio_start)

                    text_len = min_ed - st
                    if text_len != 0:
                        llm_pos_ids_list.append(
                            torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                        )
                        st_idx += text_len
                    # Audio in Video
                    if (
                        min_ed == ed_vision_start
                        and ed_vision_start + 1 == ed_audio_start
                    ):
                        bos_len, eos_len = 2, 2
                    else:
                        bos_len, eos_len = 1, 1
                    llm_pos_ids_list.append(
                        torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx
                    )
                    st_idx += bos_len
                    # Audio Only
                    if min_ed == ed_audio_start:
                        audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
                            audio_seqlens[audio_idx]
                        )
                        llm_pos_ids = (
                            torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
                        )
                        llm_pos_ids_list.append(llm_pos_ids)

                        st += int(text_len + bos_len + audio_len + eos_len)
                        audio_idx += 1
                        remain_audios -= 1
                    # Image Only
                    elif (
                        min_ed == ed_vision_start
                        and current_input_ids[ed_vision_start + 1] == image_token_id
                    ):
                        grid_t = image_grid_thw[image_idx][0]
                        grid_hs = image_grid_thw[:, 1]
                        grid_ws = image_grid_thw[:, 2]
                        t_index = (
                            torch.arange(grid_t) * 1 * position_id_per_seconds
                        ).float()
                        llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
                            st_idx,
                            image_idx,
                            spatial_merge_size,
                            t_index,
                            grid_hs,
                            grid_ws,
                            input_ids.device,
                        )
                        image_len = image_grid_thw[image_idx].prod() // (
                            spatial_merge_size**2
                        )
                        llm_pos_ids_list.append(llm_pos_ids)

                        st += int(text_len + bos_len + image_len + eos_len)
                        image_idx += 1
                        remain_images -= 1
                    # Video Only
                    elif (
                        min_ed == ed_vision_start
                        and current_input_ids[ed_vision_start + 1] == video_token_id
                    ):
                        grid_t = video_grid_thw[video_idx][0]
                        grid_hs = video_grid_thw[:, 1]
                        grid_ws = video_grid_thw[:, 2]
                        t_index = (
                            torch.arange(grid_t)
                            * second_per_grids[video_idx].cpu().float()
                            * position_id_per_seconds
                        ).float()
                        llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
                            st_idx,
                            video_idx,
                            spatial_merge_size,
                            t_index,
                            grid_hs,
                            grid_ws,
                            input_ids.device,
                        )
                        video_len = video_grid_thw[video_idx].prod() // (
                            spatial_merge_size**2
                        )
                        llm_pos_ids_list.append(llm_pos_ids)

                        st += int(text_len + bos_len + video_len + eos_len)
                        video_idx += 1
                        remain_videos -= 1
                    # Audio in Video
                    elif (
                        min_ed == ed_vision_start
                        and ed_vision_start + 1 == ed_audio_start
                    ):
                        audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
                            audio_seqlens[audio_idx]
                        )
                        audio_llm_pos_ids = (
                            torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
                        )
                        grid_t = video_grid_thw[video_idx][0]
                        grid_hs = video_grid_thw[:, 1]
                        grid_ws = video_grid_thw[:, 2]
                        t_index = (
                            torch.arange(grid_t)
                            * second_per_grids[video_idx].cpu().float()
                            * position_id_per_seconds
                        ).float()
                        video_llm_pos_ids = (
                            MRotaryEmbedding._get_llm_pos_ids_for_vision(
                                st_idx,
                                video_idx,
                                spatial_merge_size,
                                t_index,
                                grid_hs,
                                grid_ws,
                                input_ids.device,
                            )
                        )
                        video_data_index, audio_data_index = 0, 0
                        while (
                            video_data_index < video_llm_pos_ids.shape[-1]
                            and audio_data_index < audio_llm_pos_ids.shape[-1]
                        ):
                            if (
                                video_llm_pos_ids[0][video_data_index]
                                <= audio_llm_pos_ids[0][audio_data_index]
                            ):
                                llm_pos_ids_list.append(
                                    video_llm_pos_ids[
                                        :, video_data_index : video_data_index + 1
                                    ]
                                )
                                video_data_index += 1
                            else:
                                llm_pos_ids_list.append(
                                    audio_llm_pos_ids[
                                        :, audio_data_index : audio_data_index + 1
                                    ]
                                )
                                audio_data_index += 1
                        if video_data_index < video_llm_pos_ids.shape[-1]:
                            llm_pos_ids_list.append(
                                video_llm_pos_ids[
                                    :, video_data_index : video_llm_pos_ids.shape[-1]
                                ]
                            )
                        if audio_data_index < audio_llm_pos_ids.shape[-1]:
                            llm_pos_ids_list.append(
                                audio_llm_pos_ids[
                                    :, audio_data_index : audio_llm_pos_ids.shape[-1]
                                ]
                            )
                        video_len = video_grid_thw[video_idx].prod() // (
                            spatial_merge_size**2
                        )

                        st += int(text_len + bos_len + audio_len + video_len + eos_len)

                        audio_idx += 1
                        video_idx += 1
                        remain_videos -= 1
                        remain_audios -= 1
                    st_idx = (
                        llm_pos_ids_list[-1].max() + 1
                        if len(llm_pos_ids_list) > 0
                        else 0
                    )
                    llm_pos_ids_list.append(
                        torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx
                    )

                if st < len(input_tokens):
                    st_idx = (
                        llm_pos_ids_list[-1].max() + 1
                        if len(llm_pos_ids_list) > 0
                        else 0
                    )
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(
                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                    )

                llm_positions = torch.cat(
                    [item.float() for item in llm_pos_ids_list], dim=1
                ).reshape(3, -1)

                position_ids[..., i, :] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(
                    llm_positions.max() + 1 - len(current_input_ids)
                )
            mrope_position_deltas = torch.tensor(
                mrope_position_deltas, device=input_ids.device
            ).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            s = input_ids.shape[1]
            position_ids = torch.arange(s)
            position_ids = (
                position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
            )
            max_position_ids = position_ids.max(0, keepdim=False)[0].max(
                -1, keepdim=True
            )[0]
            mrope_position_deltas = max_position_ids + 1 - s
            return position_ids, mrope_position_deltas

    # Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
    @staticmethod
    def get_rope_index_glm4v(

@@ -1504,6 +1821,44 @@ class MRotaryEmbedding(RotaryEmbedding):
            return position_ids, mrope_position_deltas

    # For qwen3-omni
    @staticmethod
    def _get_feat_extract_output_lengths(input_lengths):
        """
        Computes the output length of the convolutional layers and the output length of the audio encoder
        """
        input_lengths_leave = input_lengths % 100
        feat_lengths = (input_lengths_leave - 1) // 2 + 1
        output_lengths = (
            ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
        )
        return output_lengths

    # For qwen3-omni
    @staticmethod
    def _get_llm_pos_ids_for_vision(
        st_idx, vision_idx, spatial_merge_size, t_index, grid_hs, grid_ws, device
    ):
        grid_h = grid_hs[vision_idx] // spatial_merge_size
        grid_w = grid_ws[vision_idx] // spatial_merge_size
        h_index = (
            torch.arange(grid_h, device=device)
            .view(1, -1, 1)
            .expand(len(t_index), -1, grid_w)
            .flatten()
        )
        w_index = (
            torch.arange(grid_w, device=device)
            .view(1, 1, -1)
            .expand(len(t_index), grid_h, -1)
            .flatten()
        )
        t_index = t_index.view(-1, 1).expand(-1, grid_h * grid_w).flatten()
        llm_pos_ids = torch.stack([t_index, h_index, w_index], dim=0) + st_idx
        return llm_pos_ids


class DualChunkRotaryEmbedding(CustomOp):
    """Rotary positional embedding for Dual Chunk Attention."""
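Aside (not part of the diff): a quick worked example of the `_get_feat_extract_output_lengths` formula added above. Each full 100-frame window appears to contribute a fixed 13 positions, while the remainder goes through the three stride-2 reductions; the helper below simply mirrors the formula for a plain integer.

    def audio_output_len(input_length: int) -> int:
        # Same arithmetic as MRotaryEmbedding._get_feat_extract_output_lengths.
        leave = input_length % 100
        feat = (leave - 1) // 2 + 1
        return ((feat - 1) // 2 + 1 - 1) // 2 + 1 + (input_length // 100) * 13

    print(audio_output_len(250))  # 33 = 2 * 13 + 7 (from the 50-frame tail)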
python/sglang/srt/managers/mm_utils.py  (view file @ 86b04d25)

@@ -280,7 +280,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
            input_ids_tensor[input_ids_tensor == token_id] = pad_value

        ret_input_ids = input_ids_tensor.tolist()
        return ret_input_ids

@@ -507,7 +506,7 @@ def embed_mm_inputs(
        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
    ] = None,
    placeholder_tokens: dict[Modality, List[int]] = None,
-   use_deepstack: bool = False,
+   use_deepstack: Dict[Modality, bool] = {},
) -> Optional[torch.Tensor]:
    """
    Embed multimodal inputs and integrate them with text token embeddings.

@@ -533,7 +532,9 @@ def embed_mm_inputs(
    for mm_inputs in mm_inputs_list:
        item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]

-   embeddings, masks, deepstack_embeddings = [], [], []
+   # deepstack_embeddings: per-modality
+   modalities, embeddings, masks, deepstack_embeddings = [], [], [], []

    # 2. Get multimodal embedding separately
    # Try get mm embedding if any
    for modality in Modality.all():

@@ -549,7 +550,8 @@ def embed_mm_inputs(
        # "image", "video", etc
        modality_id = modality.name.lower()
        embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
-       if len(items) != 0 and embedder is not None:
+       if len(items) != 0:
+           assert embedder is not None, f"no embedding method found for {modality}"
            placeholder_tensor = torch.as_tensor(
                [item.pad_value for item in items],
                device=input_ids.device,

@@ -580,11 +582,12 @@ def embed_mm_inputs(
                items_offset_list=items_offsets,
            )
-           if use_deepstack and embedding is not None:
+           if use_deepstack.get(modality, None) and embedding is not None:
                embedding, deepstack_embedding = (
                    multimodal_model.separate_deepstack_embeds(embedding)
                )
                deepstack_embeddings += [deepstack_embedding]
+           modalities += [modality]
            embeddings += [embedding]
            masks += [mask]

@@ -597,17 +600,14 @@ def embed_mm_inputs(
    input_ids.clamp_(min=0, max=vocab_size - 1)
    inputs_embeds = input_embedding(input_ids)

-   # 4. scatter embeddings into input embedding
    # deepstack embedding
    if use_deepstack:
-       num_deepstack_embeddings = (
-           len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
-       )
+       num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes)
        deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
            inputs_embeds.shape[-1] * num_deepstack_embeddings,
        )
        # a zero-filled embedding, with the same length of inputs_embeds, but different hidden_size
        input_deepstack_embeds = torch.zeros(
            deepstack_embedding_shape,
            device=inputs_embeds.device,

@@ -616,14 +616,16 @@ def embed_mm_inputs(
        other_info["input_deepstack_embeds"] = input_deepstack_embeds

-   for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
+   # 4. scatter embeddings into input embedding
+   for i, modality, embedding, mask in zip(
+       range(len(embeddings)), modalities, embeddings, masks
+   ):
        if embedding is None or mask is None:
            continue
        # in-place update
        indices = torch.where(mask.squeeze(dim=-1))[0]
        inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
-       if use_deepstack:
+       if use_deepstack.get(modality, None):
            input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
                inputs_embeds.device, inputs_embeds.dtype
            )

@@ -640,7 +642,7 @@ def general_mm_embed_routine(
        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
    ] = None,
    placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
-   use_deepstack: bool = False,
+   use_deepstack: Dict[Modality, bool] = {},
    **kwargs,
) -> torch.Tensor:
    """

@@ -652,7 +654,7 @@ def general_mm_embed_routine(
        language_model: Base language model to use
        data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
        placeholder_tokens: Token IDs for multimodal placeholders
-       use_deepstack: Whether to use deepstack embeddings
+       use_deepstack: Whether to use deepstack embeddings for each modality, default False
        **kwargs: Additional arguments passed to language model

    Returns:
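Aside (not part of the diff): `use_deepstack` changes from a single bool to a per-modality mapping, so a model can enable deepstack embeddings for vision inputs while leaving audio out. A sketch of what a caller might now pass; the import path for `Modality` is an assumption based on the other types used in this file.

    from sglang.srt.managers.schedule_batch import Modality

    # Handed to general_mm_embed_routine(...) from the model's forward;
    # audio items would simply not appear in the mapping.
    use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True}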
python/sglang/srt/managers/tokenizer_manager.py  (view file @ 86b04d25)

@@ -587,9 +587,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
            )

        if self.mm_processor and obj.contains_mm_input():
-           if not isinstance(obj.image_data, list) and obj.image_data:
+           if obj.image_data is not None and not isinstance(obj.image_data, list):
                obj.image_data = [obj.image_data]
-           if not isinstance(obj.audio_data, list) and obj.audio_data:
+           if obj.audio_data is not None and not isinstance(obj.audio_data, list):
                obj.audio_data = [obj.audio_data]
            mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
                image_data=obj.image_data,
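Aside (not part of the diff): the guard now keys on `is not None` rather than truthiness, so the apparent effect (illustrated standalone below) is that any present, non-list payload is wrapped in a list, while only a genuinely missing value is left alone.

    def normalize(data):
        # Same shape as the new image_data/audio_data handling above.
        if data is not None and not isinstance(data, list):
            data = [data]
        return data

    print(normalize("a.png"))    # ['a.png']
    print(normalize(["a.png"]))  # ['a.png'] (already a list)
    print(normalize(None))       # None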
python/sglang/srt/models/qwen2_moe.py  (view file @ 86b04d25)

@@ -518,6 +518,7 @@ class Qwen2MoeModel(nn.Module):
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()
python/sglang/srt/models/qwen3_moe.py  (view file @ 86b04d25)

@@ -661,13 +661,14 @@ class Qwen3MoeModel(Qwen2MoeModel):
        config: Qwen3MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        decoder_layer_type=Qwen3MoeDecoderLayer,
    ) -> None:
        alt_stream = torch.cuda.Stream() if _is_cuda else None
        super().__init__(
            config=config,
            quant_config=quant_config,
            prefix=prefix,
-           decoder_layer_type=Qwen3MoeDecoderLayer,
+           decoder_layer_type=decoder_layer_type,
            alt_stream=alt_stream,
        )
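Aside (not part of the diff): the new `decoder_layer_type` parameter lets a subclass reuse `Qwen3MoeModel` with its own decoder layer, presumably so the Qwen3-VL/Omni variants can inject theirs. An illustrative subclass; the class names below are hypothetical.

    from sglang.srt.models.qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeModel

    class MyDecoderLayer(Qwen3MoeDecoderLayer):
        pass  # customize attention / MLP here

    class MyMoeModel(Qwen3MoeModel):
        def __init__(self, config, quant_config=None, prefix=""):
            super().__init__(
                config=config,
                quant_config=quant_config,
                prefix=prefix,
                decoder_layer_type=MyDecoderLayer,
            )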
python/sglang/srt/models/qwen3_omni_moe.py
0 → 100644
View file @
86b04d25
# Copyright 2025 Qwen Team
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
transformers
import
PreTrainedModel
from
transformers.activations
import
ACT2FN
from
transformers.modeling_outputs
import
BaseModelOutput
from
sglang.srt.configs.qwen3_omni
import
(
Qwen3OmniMoeAudioEncoderConfig
,
Qwen3OmniMoeThinkerConfig
,
Qwen3OmniMoeVisionEncoderConfig
,
)
from
sglang.srt.configs.qwen3_vl
import
Qwen3VLMoeConfig
from
sglang.srt.layers.attention.vision
import
VisionAttention
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.linear
import
ColumnParallelLinear
,
RowParallelLinear
from
sglang.srt.layers.moe.fused_moe_triton.layer
import
FusedMoE
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.managers.schedule_batch
import
MultimodalDataItem
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.models.qwen3_vl
import
Qwen3VLMoeVisionModel
from
sglang.srt.models.qwen3_vl_moe
import
(
Qwen3MoeLLMModel
,
Qwen3VLMoeForConditionalGeneration
,
load_fused_expert_weights
,
)
from
sglang.srt.utils
import
add_prefix
,
logger
class
Qwen3OmniMoeAudioEncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Qwen3OmniMoeAudioEncoderConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
embed_dim
=
config
.
d_model
self
.
embed_dim
=
config
.
d_model
self
.
self_attn
=
VisionAttention
(
embed_dim
=
embed_dim
,
num_heads
=
config
.
encoder_attention_heads
,
projection_size
=
embed_dim
,
use_qkv_parallel
=
True
,
rotary_embed
=
"normal"
,
proj_bias
=
True
,
qkv_backend
=
"fa3"
,
softmax_in_single_precision
=
False
,
flatten_batch
=
True
,
quant_config
=
quant_config
,
prefix
=
add_prefix
(
"attn"
,
prefix
),
)
self
.
self_attn_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
)
self
.
dropout
=
config
.
dropout
self
.
activation_fn
=
ACT2FN
[
config
.
activation_function
]
self
.
activation_dropout
=
config
.
activation_dropout
self
.
fc1
=
nn
.
Linear
(
self
.
embed_dim
,
config
.
encoder_ffn_dim
)
self
.
fc2
=
nn
.
Linear
(
config
.
encoder_ffn_dim
,
self
.
embed_dim
)
self
.
final_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
**
kwargs
,
)
->
torch
.
Tensor
:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual
=
hidden_states
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
x
=
hidden_states
,
cu_seqlens
=
cu_seqlens
,
)
hidden_states
=
residual
+
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
final_layer_norm
(
hidden_states
)
hidden_states
=
self
.
fc1
(
hidden_states
)
hidden_states
=
self
.
activation_fn
(
hidden_states
)
hidden_states
=
self
.
fc2
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
if
hidden_states
.
dtype
==
torch
.
float16
:
clamp_value
=
torch
.
finfo
(
hidden_states
.
dtype
).
max
-
1000
hidden_states
=
torch
.
clamp
(
hidden_states
,
min
=-
clamp_value
,
max
=
clamp_value
)
outputs
=
(
hidden_states
,)
return
outputs
class
SinusoidsPositionEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
length
,
channels
,
max_timescale
=
10000
):
super
().
__init__
()
if
channels
%
2
!=
0
:
raise
ValueError
(
"SinusoidsPositionEmbedding needs even channels input"
)
log_timescale_increment
=
np
.
log
(
max_timescale
)
/
(
channels
//
2
-
1
)
inv_timescales
=
torch
.
exp
(
-
log_timescale_increment
*
torch
.
arange
(
channels
//
2
).
float
()
)
scaled_time
=
(
torch
.
arange
(
length
)[:,
np
.
newaxis
]
*
inv_timescales
[
np
.
newaxis
,
:]
)
self
.
register_buffer
(
"positional_embedding"
,
torch
.
cat
([
torch
.
sin
(
scaled_time
),
torch
.
cos
(
scaled_time
)],
dim
=
1
),
persistent
=
False
,
)
def
forward
(
self
,
seqlen
:
int
):
return
self
.
positional_embedding
[:
seqlen
,
:]
def
_get_feat_extract_output_lengths
(
input_lengths
):
"""
Computes the output length of the convolutional layers and the output length of the audio encoder
"""
input_lengths_leave
=
input_lengths
%
100
feat_lengths
=
(
input_lengths_leave
-
1
)
//
2
+
1
output_lengths
=
(
((
feat_lengths
-
1
)
//
2
+
1
-
1
)
//
2
+
1
+
(
input_lengths
//
100
)
*
13
)
return
output_lengths
class
Qwen3OmniMoeAudioEncoder
(
PreTrainedModel
):
config
:
Qwen3OmniMoeAudioEncoderConfig
def
__init__
(
self
,
config
:
Qwen3OmniMoeAudioEncoderConfig
):
super
().
__init__
(
config
)
self
.
dropout
=
config
.
dropout
embed_dim
=
config
.
d_model
self
.
num_mel_bins
=
config
.
num_mel_bins
self
.
max_source_positions
=
config
.
max_source_positions
self
.
embed_scale
=
math
.
sqrt
(
embed_dim
)
if
config
.
scale_embedding
else
1.0
self
.
n_window
=
config
.
n_window
self
.
positional_embedding
=
SinusoidsPositionEmbedding
(
self
.
max_source_positions
,
embed_dim
)
self
.
layers
=
nn
.
ModuleList
(
[
Qwen3OmniMoeAudioEncoderLayer
(
config
)
for
_
in
range
(
config
.
encoder_layers
)
]
)
self
.
ln_post
=
nn
.
LayerNorm
(
config
.
d_model
)
self
.
gradient_checkpointing
=
False
self
.
conv2d1
=
nn
.
Conv2d
(
1
,
config
.
downsample_hidden_size
,
3
,
2
,
padding
=
1
)
self
.
conv2d2
=
nn
.
Conv2d
(
config
.
downsample_hidden_size
,
config
.
downsample_hidden_size
,
3
,
2
,
padding
=
1
,
)
self
.
conv2d3
=
nn
.
Conv2d
(
config
.
downsample_hidden_size
,
config
.
downsample_hidden_size
,
3
,
2
,
padding
=
1
,
)
self
.
conv_out
=
nn
.
Linear
(
config
.
downsample_hidden_size
*
((((
config
.
num_mel_bins
+
1
)
//
2
+
1
)
//
2
+
1
)
//
2
),
config
.
d_model
,
bias
=
False
,
)
self
.
proj1
=
nn
.
Linear
(
config
.
d_model
,
config
.
d_model
)
self
.
act
=
ACT2FN
[
config
.
activation_function
]
self
.
proj2
=
nn
.
Linear
(
config
.
d_model
,
config
.
output_dim
)
self
.
n_window_infer
=
self
.
config
.
n_window_infer
self
.
conv_chunksize
=
self
.
config
.
conv_chunksize
def
_freeze_parameters
(
self
):
for
param
in
self
.
parameters
():
param
.
requires_grad
=
False
self
.
_requires_grad
=
False
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
return
self
.
conv1
def
set_input_embeddings
(
self
,
value
:
nn
.
Module
):
self
.
conv1
=
value
def
forward
(
self
,
input_features
,
feature_lens
=
None
,
aftercnn_lens
=
None
,
):
r
"""
feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
mel length
aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
mel length after cnn
"""
aftercnn_lens
=
_get_feat_extract_output_lengths
(
feature_lens
)
chunk_num
=
torch
.
ceil
(
feature_lens
/
(
self
.
n_window
*
2
)).
long
()
chunk_lengths
=
torch
.
tensor
(
[
self
.
n_window
*
2
]
*
chunk_num
.
sum
(),
dtype
=
torch
.
long
,
device
=
feature_lens
.
device
,
)
tail_chunk_index
=
F
.
pad
(
chunk_num
,
(
1
,
0
),
value
=-
1
).
cumsum
(
0
)[
1
:]
chunk_lengths
[
tail_chunk_index
]
=
feature_lens
%
(
self
.
n_window
*
2
)
chunk_lengths
[
chunk_lengths
==
0
]
=
self
.
n_window
*
2
chunk_list
=
input_features
.
T
.
split
(
chunk_lengths
.
tolist
(),
dim
=
0
)
padded_feature
=
nn
.
utils
.
rnn
.
pad_sequence
(
chunk_list
,
batch_first
=
True
).
transpose
(
1
,
2
)
feature_lens_after_cnn
=
_get_feat_extract_output_lengths
(
chunk_lengths
)
padded_mask_after_cnn
=
nn
.
utils
.
rnn
.
pad_sequence
(
[
torch
.
ones
(
length
,
dtype
=
torch
.
bool
,
device
=
padded_feature
.
device
)
for
length
in
feature_lens_after_cnn
],
batch_first
=
True
,
)
padded_feature
=
padded_feature
.
unsqueeze
(
1
)
# Split to chunk to avoid OOM during convolution
padded_embeds
=
[]
for
chunk
in
padded_feature
.
split
(
self
.
conv_chunksize
,
dim
=
0
):
padded_embed
=
F
.
gelu
(
self
.
conv2d1
(
chunk
))
padded_embed
=
F
.
gelu
(
self
.
conv2d2
(
padded_embed
))
padded_embed
=
F
.
gelu
(
self
.
conv2d3
(
padded_embed
))
padded_embeds
.
append
(
padded_embed
)
padded_embed
=
torch
.
cat
(
padded_embeds
,
dim
=
0
)
b
,
c
,
f
,
t
=
padded_embed
.
size
()
padded_embed
=
self
.
conv_out
(
padded_embed
.
permute
(
0
,
3
,
1
,
2
).
contiguous
().
view
(
b
,
t
,
c
*
f
)
)
positional_embedding
=
(
self
.
positional_embedding
.
positional_embedding
[:
padded_embed
.
shape
[
1
],
:]
.
unsqueeze
(
0
)
.
to
(
padded_embed
.
dtype
)
)
padded_embed
=
padded_embed
+
positional_embedding
hidden_states
=
padded_embed
[
padded_mask_after_cnn
]
cu_chunk_lens
=
[
0
]
window_aftercnn
=
padded_mask_after_cnn
.
shape
[
-
1
]
*
(
self
.
n_window_infer
//
(
self
.
n_window
*
2
)
)
for
cnn_len
in
aftercnn_lens
:
cu_chunk_lens
+=
[
window_aftercnn
]
*
(
cnn_len
//
window_aftercnn
)
remainder
=
cnn_len
%
window_aftercnn
if
remainder
!=
0
:
cu_chunk_lens
+=
[
remainder
]
cu_seqlens
=
torch
.
tensor
(
cu_chunk_lens
,
device
=
aftercnn_lens
.
device
).
cumsum
(
-
1
,
dtype
=
torch
.
int32
)
for
encoder_layer
in
self
.
layers
:
layer_outputs
=
encoder_layer
(
hidden_states
,
cu_seqlens
,
)
hidden_states
=
layer_outputs
[
0
]
hidden_states
=
self
.
ln_post
(
hidden_states
)
hidden_states
=
self
.
proj1
(
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
proj2
(
hidden_states
)
return
BaseModelOutput
(
last_hidden_state
=
hidden_states
)
# Ignore copy
def
_get_feat_extract_output_lengths
(
self
,
input_lengths
:
torch
.
LongTensor
):
"""
Computes the output length of the convolutional layers and the output length of the audio encoder
"""
input_lengths
=
(
input_lengths
-
1
)
//
2
+
1
output_lengths
=
(
input_lengths
-
2
)
//
2
+
1
return
input_lengths
,
output_lengths
class
Qwen3OmniMoeVisionPatchMerger
(
nn
.
Module
):
def
__init__
(
self
,
dim
:
int
,
context_dim
:
int
,
spatial_merge_size
:
int
=
2
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
use_postshuffle_norm
=
False
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
context_dim
*
(
spatial_merge_size
**
2
)
self
.
use_postshuffle_norm
=
use_postshuffle_norm
self
.
ln_q
=
RMSNorm
(
self
.
hidden_size
if
use_postshuffle_norm
else
context_dim
,
eps
=
1e-6
)
self
.
mlp
=
nn
.
ModuleList
(
[
ColumnParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
add_prefix
(
"mlp.0"
,
prefix
),
),
nn
.
GELU
(),
RowParallelLinear
(
self
.
hidden_size
,
dim
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
add_prefix
(
"mlp.2"
,
prefix
),
),
]
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
=
(
x
.
view
(
-
1
,
self
.
hidden_size
)
if
self
.
use_postshuffle_norm
else
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
)
hidden
=
self
.
ln_q
(
x
).
view
(
-
1
,
self
.
hidden_size
)
for
layer
in
self
.
mlp
:
if
isinstance
(
hidden
,
tuple
):
hidden
=
hidden
[
0
]
hidden
=
layer
(
hidden
)
if
isinstance
(
hidden
,
tuple
):
hidden
=
hidden
[
0
]
return
hidden
class
Qwen3OmniMoeVisionEncoder
(
Qwen3VLMoeVisionModel
):
config
:
Qwen3OmniMoeVisionEncoderConfig
def
__init__
(
self
,
config
:
Qwen3OmniMoeVisionEncoderConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
None
,
**
kwargs
,
):
super
().
__init__
(
vision_config
=
config
,
quant_config
=
quant_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
)
self
.
merger
=
Qwen3OmniMoeVisionPatchMerger
(
dim
=
config
.
out_hidden_size
,
context_dim
=
config
.
hidden_size
,
spatial_merge_size
=
config
.
spatial_merge_size
,
quant_config
=
quant_config
,
use_postshuffle_norm
=
False
,
prefix
=
add_prefix
(
"merger"
,
prefix
),
)
self
.
merger_list
=
nn
.
ModuleList
(
[
Qwen3OmniMoeVisionPatchMerger
(
dim
=
config
.
out_hidden_size
,
context_dim
=
config
.
hidden_size
,
spatial_merge_size
=
config
.
spatial_merge_size
,
use_postshuffle_norm
=
True
,
quant_config
=
quant_config
,
prefix
=
add_prefix
(
"merger_list"
,
prefix
),
)
for
_
in
range
(
len
(
config
.
deepstack_visual_indexes
))
]
)
del
self
.
deepstack_merger_list
@
property
def
deepstack_merger_list
(
self
):
return
self
.
merger_list
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
return
self
.
patch_embed
.
proj
.
weight
.
dtype
@
property
def
device
(
self
)
->
torch
.
device
:
return
self
.
patch_embed
.
proj
.
weight
.
device
class
Qwen3OmniMoeThinkerForConditionalGeneration
(
Qwen3VLMoeForConditionalGeneration
):
config
:
Qwen3OmniMoeThinkerConfig
def
__init__
(
self
,
config
:
Qwen3OmniMoeThinkerConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
(
config
,
quant_config
,
prefix
,
language_model_cls
=
Qwen3MoeLLMModel
)
self
.
audio_tower
=
Qwen3OmniMoeAudioEncoder
(
config
.
audio_config
)
self
.
visual
=
Qwen3OmniMoeVisionEncoder
(
config
.
vision_config
,
quant_config
=
quant_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
prefix
=
add_prefix
(
"visual"
,
prefix
),
)
self
.
pad_token_id
=
(
self
.
config
.
pad_token_id
if
self
.
config
.
pad_token_id
is
not
None
else
-
1
)
def
get_audio_feature
(
self
,
items
:
List
[
MultimodalDataItem
]):
feature_attention_mask
=
torch
.
cat
(
[
item
.
feature_attention_mask
for
item
in
items
],
dim
=
0
).
type
(
torch
.
long
)
input_features
=
(
torch
.
cat
([
item
.
feature
for
item
in
items
])
.
type
(
self
.
audio_tower
.
dtype
)
.
to
(
next
(
self
.
audio_tower
.
parameters
()).
device
)
)
if
feature_attention_mask
is
not
None
:
audio_feature_lengths
=
torch
.
sum
(
feature_attention_mask
,
dim
=
1
)
input_features
=
input_features
.
permute
(
0
,
2
,
1
)[
feature_attention_mask
.
bool
()
].
permute
(
1
,
0
)
else
:
audio_feature_lengths
=
None
feature_lens
=
(
audio_feature_lengths
if
audio_feature_lengths
is
not
None
else
feature_attention_mask
.
sum
(
-
1
)
)
audio_outputs
=
self
.
audio_tower
(
input_features
,
feature_lens
=
feature_lens
,
)
audio_features
=
audio_outputs
.
last_hidden_state
return
audio_features
class
Qwen3OmniMoeForConditionalGeneration
(
PreTrainedModel
):
def
__init__
(
self
,
config
:
Qwen3VLMoeConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
(
config
)
self
.
config
=
config
self
.
thinker
=
Qwen3OmniMoeThinkerForConditionalGeneration
(
config
.
thinker_config
,
quant_config
=
quant_config
,
prefix
=
prefix
)
self
.
enable_talker
=
False
self
.
pad_input_ids
=
self
.
thinker
.
pad_input_ids
self
.
forward
=
self
.
thinker
.
forward
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )
        # Skip loading extra parameters for GPTQ/modelopt models.
        ignore_suffixes = (
            ".bias", "_bias",
            ".k_scale", "_k_scale",
            ".v_scale", "_v_scale",
            ".weight_scale", "_weight_scale",
            ".input_scale", "_input_scale",
        )
        is_fused_expert = False
        fused_expert_params_mapping = [
            ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
            ("experts.w2_weight", "experts.down_proj", 0, "w2"),
        ]
        num_experts = self.config.num_experts
        # Cache params_dict to avoid repeated expensive traversal of model parameters
        if not hasattr(self, "_cached_params_dict"):
            self._cached_params_dict = dict(self.named_parameters())
        params_dict = self._cached_params_dict
        for name, loaded_weight in weights:
            name = name.replace(r"model.language_model.", r"model.")
            if ("talker" in name or "code2wav" in name) and not self.enable_talker:
                continue
            name = name.replace(".self_attn.out_proj", ".self_attn.proj")
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if "experts.gate_up_proj" in name or "experts.down_proj" in name:
                    is_fused_expert = True
                    expert_params_mapping = fused_expert_params_mapping
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                if "visual" in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra parameters for GPTQ/modelopt models.
                if name.endswith(ignore_suffixes) and name not in params_dict:
                    continue
                # [TODO] Skip layers that are on other devices (check if sglang has a similar function)
                # if is_pp_missing_parameter(name, self):
                #     continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Track if this is an expert weight to enable early skipping
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    if "visual" in name or "audio_tower" in name:
                        continue
                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True
                    name_mapped = name.replace(weight_name, param_name)
                    if is_fused_expert:
                        loaded_weight = loaded_weight.transpose(-1, -2)  # no bias
                        if "experts.gate_up_proj" in name:
                            loaded_weight = loaded_weight.chunk(2, dim=-2)
                            load_fused_expert_weights(
                                name_mapped,
                                params_dict,
                                loaded_weight[0],
                                "w1",
                                num_experts,
                            )
                            load_fused_expert_weights(
                                name_mapped,
                                params_dict,
                                loaded_weight[1],
                                "w3",
                                num_experts,
                            )
                        else:
                            load_fused_expert_weights(
                                name_mapped,
                                params_dict,
                                loaded_weight,
                                shard_id,
                                num_experts,
                            )
                    else:
                        # Skip loading extra parameters for GPTQ/modelopt models.
                        if (
                            name_mapped.endswith(ignore_suffixes)
                            and name_mapped not in params_dict
                        ):
                            continue
                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or
                        # not here since otherwise we may skip experts with
                        # other available replicas.
                        weight_loader = param.weight_loader
                        weight_loader(
                            param,
                            loaded_weight,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                        )
                    name = name_mapped
                    break
                else:
                    if is_expert_weight:
                        # This is an expert weight but not mapped to this rank, skip all remaining processing
                        continue
                    if "visual" in name or "audio_tower" in name:
                        # adapt to VisionAttention
                        name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
                        name = name.replace(r"model.visual.", r"visual.")
                        name = name.replace(r"attn.out_proj.", r"attn.proj.")
                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if name.endswith(ignore_suffixes) and name not in params_dict:
                        continue
                    if name in params_dict.keys():
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
                    else:
                        logger.warning(
                            f"Loaded weight with {name=} not found in params_dict"
                        )


EntryClass = Qwen3OmniMoeForConditionalGeneration
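As a rough illustration of what the name remapping in load_weights does (a hypothetical standalone helper, not part of the commit), a checkpoint entry for a single q/k/v projection is redirected to the fused qkv_proj parameter together with its shard id:

stacked_params_mapping = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    ("gate_up_proj", "up_proj", 1),
    ("gate_up_proj", "gate_proj", 0),
]


def map_checkpoint_name(name: str):
    # Hypothetical standalone version of the rename logic above.
    name = name.replace("model.language_model.", "model.")
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None


print(map_checkpoint_name("model.language_model.layers.0.self_attn.q_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 'q')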
python/sglang/srt/models/qwen3_vl.py
View file @
86b04d25
...
...
@@ -15,7 +15,7 @@
 """Inference-only Qwen3-VL model compatible with HuggingFace weights."""
 import logging
 from functools import lru_cache, partial
-from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
+from typing import Callable, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import torch
...
...
@@ -27,7 +27,11 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VisionRotaryEmbedding,
 )
-from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
+from sglang.srt.configs.qwen3_vl import (
+    Qwen3VLConfig,
+    Qwen3VLTextConfig,
+    Qwen3VLVisionConfig,
+)
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
...
...
@@ -38,16 +42,24 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
 from sglang.srt.models.qwen3 import Qwen3Model
 from sglang.srt.utils import add_prefix
 from sglang.srt.utils.hf_transformers_utils import get_processor

 logger = logging.getLogger(__name__)

 # === Vision Encoder === #
...
...
@@ -196,7 +208,7 @@ class Qwen3_VisionBlock(nn.Module):
         return x


-class Qwen3_VisionPatchMerger(nn.Module):
+class Qwen3VLMoeVisionPatchMerger(nn.Module):

     def __init__(
         self,
...
...
@@ -246,7 +258,7 @@ class Qwen3_VisionPatchMerger(nn.Module):
         return out


-class Qwen3_VisionTransformer(nn.Module):
+class Qwen3VLMoeVisionModel(nn.Module):

     def __init__(
         self,
...
...
@@ -263,10 +275,10 @@ class Qwen3_VisionTransformer(nn.Module):
         self.spatial_merge_size = vision_config.spatial_merge_size
         self.spatial_merge_unit = self.spatial_merge_size**2
         self.temporal_patch_size = vision_config.temporal_patch_size
         # layer indexes of which layer's output should be deep-stacked
         self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
         self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
         self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
         norm_layer = partial(nn.LayerNorm, eps=norm_eps)
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
...
...
@@ -286,7 +298,7 @@ class Qwen3_VisionTransformer(nn.Module):
                 for layer_idx in range(vision_config.depth)
             ]
         )
-        self.merger = Qwen3_VisionPatchMerger(
+        self.merger = Qwen3VLMoeVisionPatchMerger(
             dim=vision_config.out_hidden_size,
             context_dim=self.hidden_size,
             norm_layer=norm_layer,
...
...
@@ -297,7 +309,7 @@ class Qwen3_VisionTransformer(nn.Module):
         self.deepstack_merger_list = nn.ModuleList(
             [
-                Qwen3_VisionPatchMerger(
+                Qwen3VLMoeVisionPatchMerger(
                     dim=vision_config.out_hidden_size,
                     context_dim=self.hidden_size,
                     spatial_merge_size=self.spatial_merge_size,
...
...
@@ -462,7 +474,6 @@ class Qwen3_VisionTransformer(nn.Module):
             ]
         )
         # max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
-        x = x.unsqueeze(1)

         deepstack_feature_lists = []
...
...
@@ -604,37 +615,43 @@ class Qwen3VLForConditionalGeneration(nn.Module):
         config: Qwen3VLConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        language_model_cls=Qwen3LLMModel,
     ) -> None:
         super().__init__()
         self.config = config
-        self.visual = Qwen3_VisionTransformer(
+        self.visual = Qwen3VLMoeVisionModel(
             config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
             # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
             quant_config=quant_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             prefix=add_prefix("visual", prefix),
         )
-        self.model = Qwen3LLMModel(
-            config=config,
+        # TODO: make it more elegant
+        if language_model_cls is Qwen3LLMModel:
+            self.config: Qwen3VLConfig = config  # for qwen3-vl
+        else:
+            self.config = config.text_config  # for qwen3-omni
+        self.model = language_model_cls(
+            config=self.config,
             quant_config=quant_config,
             prefix=add_prefix("model", prefix),
         )
-        if config.tie_word_embeddings:
+        if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
+                self.config.vocab_size,
+                self.config.hidden_size,
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
             )
         self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
-        self.logits_processor = LogitsProcessor(config)
+        self.logits_processor = LogitsProcessor(self.config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
         # like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on
         # 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states
...
...
@@ -642,10 +659,7 @@ class Qwen3VLForConditionalGeneration(nn.Module):
         # deepstack
         self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
         self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
-
-    @property
-    def use_deepstack(self) -> bool:
-        return hasattr(self, "deepstack_visual_indexes")
+        self.use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True}

     def separate_deepstack_embeds(self, embedding):
         assert (
...
...
python/sglang/srt/models/qwen3_vl_moe.py
View file @
86b04d25
...
...
@@ -14,29 +14,19 @@
 # ==============================================================================
 """Inference-only Qwen3-VL model compatible with HuggingFace weights."""
 import logging
-from functools import lru_cache, partial
-from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
+from functools import lru_cache
+from typing import Iterable, Optional, Tuple, Union

-import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-from transformers import BatchFeature
-from transformers.activations import ACT2FN
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-    Qwen2_5_VisionRotaryEmbedding,
-)

-from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeVisionConfig
+from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_pp_group,
     get_tensor_model_parallel_rank,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.mm_utils import general_mm_embed_routine
...
...
@@ -44,11 +34,7 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen3_moe import Qwen3MoeModel
-from sglang.srt.models.qwen3_vl import (
-    Qwen3_VisionTransformer,
-    Qwen3VLForConditionalGeneration,
-)
-from sglang.srt.utils import add_prefix
+from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
 from sglang.srt.utils.hf_transformers_utils import get_processor

 logger = logging.getLogger(__name__)
...
...
@@ -60,28 +46,16 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
     def __init__(
         self,
         *,
-        config: Qwen3VLMoeConfig,
+        config: Qwen3VLMoeTextConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__(config=config, quant_config=quant_config, prefix=prefix)
         self.hidden_size = config.hidden_size

     def get_input_embeddings(self) -> nn.Embedding:
         return self.embed_tokens

-    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-        # in qwen-vl, last dim is the same
-        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
-            self.visual.dtype
-        )
-        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
-        assert pixel_values.dim() == 2, pixel_values.dim()
-        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
-        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-        return image_embeds
-
     def forward(
         self,
         input_ids: torch.Tensor,
...
...
@@ -120,7 +94,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
             )

             # process deepstack
-            if input_deepstack_embeds is not None and layer_idx in range(3):
+            if input_deepstack_embeds is not None and layer_idx < 3:
                 sep = self.hidden_size * layer_idx
                 hidden_states.add_(
                     input_deepstack_embeds[:, sep : sep + self.hidden_size]
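A minimal sketch of the slicing arithmetic above (toy sizes, not commit code): the deepstack embeddings for the first three decoder layers are laid out side by side along the hidden dimension, and each layer adds back its own hidden_size-wide slice.

import torch

hidden_size, num_tokens = 8, 3
# One chunk per deep-stacked vision layer, concatenated: (tokens, 3 * hidden_size)
input_deepstack_embeds = torch.randn(num_tokens, 3 * hidden_size)
hidden_states = torch.zeros(num_tokens, hidden_size)

for layer_idx in range(3):
    sep = hidden_size * layer_idx
    hidden_states.add_(input_deepstack_embeds[:, sep : sep + hidden_size])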
...
...
@@ -146,144 +120,56 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
         return hidden_states, aux_hidden_states


-class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
-    def __init__(
-        self,
-        *,
-        config: Qwen3VLMoeConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
-        super(Qwen3VLForConditionalGeneration, self).__init__()
-        self.config = config
-        self.visual = Qwen3_VisionTransformer(
-            config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
-            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
-            quant_config=quant_config,
-            prefix=add_prefix("visual", prefix),
-        )
-        self.model = Qwen3MoeLLMModel(
-            config=config,
-            quant_config=quant_config,
-            prefix=add_prefix("model", prefix),
-        )
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
-        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
-        self.logits_processor = LogitsProcessor(config)
-        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-        # deepstack
-        self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
-        self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
-
-    @property
-    def use_deepstack(self) -> bool:
-        return hasattr(self, "deepstack_visual_indexes")
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        forward_batch: ForwardBatch,
-        get_embedding: bool = False,
-    ):
-        """Run forward pass for Qwen3-VL.
-
-        Args:
-            input_ids: Flattened (concatenated) input_ids corresponding to a
-                batch.
-            positions: Flattened (concatenated) position ids corresponding to a
-                batch.
-                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
-                opensource models), the shape will be `(3, seq_len)`,
-                otherwise it will be `(seq_len,).
-                (Use input_metadata.mrope_positions to replace it)
-        """
-        if self.is_mrope_enabled:
-            positions = forward_batch.mrope_positions
-        if not (
-            forward_batch.forward_mode.is_decode()
-            or not forward_batch.contains_image_inputs()
-        ):
-            if self.is_mrope_enabled:
-                assert positions.ndim == 2 and positions.size(0) == 3, (
-                    "multimodal section rotary embedding requires "
-                    f"(3, seq_len) positions, but got {positions.size()}"
-                )
-        hidden_states = general_mm_embed_routine(
-            input_ids=input_ids,
-            forward_batch=forward_batch,
-            language_model=self.model,
-            multimodal_model=self,
-            positions=positions,
-            use_deepstack=self.use_deepstack,
-        )
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
-        else:
-            return self.pooler(hidden_states, forward_batch)
-
-    def load_fused_expert_weights(
-        self,
-        name: str,
-        params_dict: dict,
-        loaded_weight: torch.Tensor,
-        shard_id: str,
-        num_experts: int,
-    ):
-        param = params_dict[name]
-        # weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
-        weight_loader = param.weight_loader
-        ep_rank = get_tensor_model_parallel_rank()
-        ep_size = get_moe_expert_parallel_world_size()
-        if ep_size == 1:
-            for expert_id in range(num_experts):
-                curr_expert_weight = loaded_weight[expert_id]
-                weight_loader(
-                    param,
-                    curr_expert_weight,
-                    name,
-                    shard_id,
-                    expert_id,
-                )
-        else:
-            experts_per_ep = num_experts // ep_size
-            start_expert = ep_rank * experts_per_ep
-            end_expert = (
-                (ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
-            )
-            for idx, expert_id in enumerate(range(start_expert, end_expert)):
-                curr_expert_weight = loaded_weight[expert_id]
-                weight_loader(
-                    param,
-                    curr_expert_weight,
-                    name,
-                    shard_id,
-                    idx,
-                )
-        return True
+def load_fused_expert_weights(
+    name: str,
+    params_dict: dict,
+    loaded_weight: torch.Tensor,
+    shard_id: str,
+    num_experts: int,
+):
+    param = params_dict[name]
+    # weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
+    weight_loader = param.weight_loader
+    ep_rank = get_tensor_model_parallel_rank()
+    ep_size = get_moe_expert_parallel_world_size()
+    if ep_size == 1:
+        for expert_id in range(num_experts):
+            curr_expert_weight = loaded_weight[expert_id]
+            weight_loader(
+                param,
+                curr_expert_weight,
+                name,
+                shard_id,
+                expert_id,
+            )
+    else:
+        experts_per_ep = num_experts // ep_size
+        start_expert = ep_rank * experts_per_ep
+        end_expert = (
+            (ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
+        )
+        for idx, expert_id in enumerate(range(start_expert, end_expert)):
+            curr_expert_weight = loaded_weight[expert_id]
+            weight_loader(
+                param,
+                curr_expert_weight,
+                name,
+                shard_id,
+                idx,
+            )
+    return True
+
+
+class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
+    def __init__(
+        self,
+        config: Qwen3VLMoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        language_model_cls=Qwen3MoeLLMModel,
+    ):
+        super().__init__(config, quant_config, prefix, language_model_cls)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
...
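For reference, the expert-parallel split inside load_fused_expert_weights boils down to the arithmetic below (illustrative numbers only): each EP rank loads a contiguous slice of experts, and the last rank absorbs the remainder.

def expert_range(ep_rank: int, ep_size: int, num_experts: int):
    experts_per_ep = num_experts // ep_size
    start_expert = ep_rank * experts_per_ep
    end_expert = (
        (ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
    )
    return start_expert, end_expert


# e.g. 10 experts over 4 ranks -> [(0, 2), (2, 4), (4, 6), (6, 10)]
print([expert_range(rank, 4, 10) for rank in range(4)])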
@@ -329,8 +215,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
             self._cached_params_dict = dict(self.named_parameters())
         params_dict = self._cached_params_dict
         for name, loaded_weight in weights:
-            if "language_model" in name:
-                name = name.replace(r"model.language_model.", r"model.")
+            name = name.replace(r"model.language_model.", r"model.")

             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if "experts.gate_up_proj" in name or "experts.down_proj" in name:
...
...
@@ -384,14 +269,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
                     loaded_weight = loaded_weight.transpose(-1, -2)  # no bias
                     if "experts.gate_up_proj" in name:
                         loaded_weight = loaded_weight.chunk(2, dim=-2)
-                        self.load_fused_expert_weights(
+                        load_fused_expert_weights(
                             name_mapped,
                             params_dict,
                             loaded_weight[0],
                             "w1",
                             num_experts,
                         )
-                        self.load_fused_expert_weights(
+                        load_fused_expert_weights(
                             name_mapped,
                             params_dict,
                             loaded_weight[1],
...
...
@@ -399,7 +284,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
                             num_experts,
                         )
                     else:
-                        self.load_fused_expert_weights(
+                        load_fused_expert_weights(
                             name_mapped,
                             params_dict,
                             loaded_weight,
...
...
python/sglang/srt/multimodal/processors/base_processor.py
View file @
86b04d25
...
...
@@ -155,7 +155,6 @@ class BaseMultimodalProcessor(ABC):
     ):
         self.hf_config = hf_config
         self._processor = _processor
-        self.arch = hf_config.architectures[0]
         self.server_args = server_args
         self.transport_mode = transport_mode
...
...
@@ -191,6 +190,7 @@ class BaseMultimodalProcessor(ABC):
             "input_features": Modality.AUDIO,
             "input_features_mask": Modality.AUDIO,
             "audio_attention_mask": Modality.AUDIO,
+            "feature_attention_mask": Modality.AUDIO,
             # Video-related attributes
             "pixel_values_videos": Modality.VIDEO,
             "second_per_grid_ts": Modality.VIDEO,
...
...
@@ -222,6 +222,7 @@ class BaseMultimodalProcessor(ABC):
         if self._processor.__class__.__name__ in {
             "Gemma3nProcessor",
             "Qwen2AudioProcessor",
+            "Qwen3OmniMoeProcessor",
         }:
             # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
             kwargs["audio"] = audios
...
...
python/sglang/srt/multimodal/processors/qwen_vl.py
View file @
86b04d25
...
...
@@ -12,6 +12,7 @@ from torchvision.transforms import InterpolationMode
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
 from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
+from sglang.srt.models.qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration
 from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
 from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
 from sglang.srt.multimodal.processors.base_processor import (
...
...
@@ -209,22 +210,31 @@ async def preprocess_video(
     return video


-# Compatible with Qwen2VL and Qwen2_5VL
-class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
+# Compatible with Qwen-VL & Qwen-Omni Series
+class QwenVLImageProcessor(SGLangBaseProcessor):
     models = [
         Qwen2VLForConditionalGeneration,
         Qwen2_5_VLForConditionalGeneration,
         Qwen3VLForConditionalGeneration,
         Qwen3VLMoeForConditionalGeneration,
+        Qwen3OmniMoeForConditionalGeneration,
     ]

     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        self.model_type = hf_config.model_type
+        if hf_config.model_type == "qwen3_omni_moe":
+            hf_config = hf_config.thinker_config
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         # The regex that matches expanded image tokens.
         self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
+        self.vision_start_token_id = hf_config.vision_start_token_id
+        self.vision_end_token_id = hf_config.vision_end_token_id
+        self.audio_start_token_id = getattr(hf_config, "audio_start_token_id", None)
+        self.audio_token_id = getattr(hf_config, "audio_token_id", None)
         self.NUM_TOKEN_PER_FRAME = 770
         self.IMAGE_FACTOR = 28
         self.MIN_PIXELS = 4 * 28 * 28
...
...
@@ -233,10 +243,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         self.mm_tokens = MultimodalSpecialTokens(
             image_token="<|vision_start|><|image_pad|><|vision_end|>",
             image_token_id=hf_config.image_token_id,
             # The regex that matches expanded image tokens.
             image_token_regex=re.compile(
                 r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
             ),
             video_token_id=hf_config.video_token_id,
+            audio_token_id=self.audio_token_id,
         ).build(_processor)

     async def process_mm_data_async(
...
...
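The image_token_regex registered above matches an entire expanded image placeholder in one shot. A quick, illustrative check (the example prompt is made up):

import re

image_token_regex = re.compile(r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>")
prompt = "Describe <|vision_start|><|image_pad|><|image_pad|><|vision_end|> briefly."
match = image_token_regex.search(prompt)
assert match is not None and match.group(0).count("<|image_pad|>") == 2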
@@ -247,11 +259,11 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
             video_data=request_obj.video_data,
             audio_data=request_obj.audio_data,
             multimodal_tokens=self.mm_tokens,
         )
...
...
@@ -269,20 +281,41 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
             base_output, self.mm_tokens
         )

+        audio_feature_lengths = None
+        if self.model_type == "qwen3_omni_moe":
+            audio_item = next((mm for mm in mm_items if mm.is_audio()), None)
+            if audio_item:
+                audio_feature_lengths = torch.sum(
+                    audio_item.feature_attention_mask, dim=1
+                )
+
+        second_per_grid_ts = getattr(ret, "second_per_grid_ts", None) or getattr(
+            ret, "video_second_per_grid", None
+        )
         input_ids = input_ids.flatten()
         mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
             spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
             image_token_id=self.mm_tokens.image_token_id,
             video_token_id=self.mm_tokens.video_token_id,
             vision_start_token_id=self.vision_start_token_id,
-            model_type=self.hf_config.model_type,
+            model_type=self.model_type,
             tokens_per_second=getattr(
                 self.hf_config.vision_config, "tokens_per_second", None
             ),
             input_ids=input_ids.unsqueeze(0),
             image_grid_thw=getattr(ret, "image_grid_thw", None),
             video_grid_thw=getattr(ret, "video_grid_thw", None),
-            second_per_grid_ts=getattr(ret, "second_per_grid_ts", None),
+            second_per_grid_ts=second_per_grid_ts,
+            use_audio_in_video=False,
+            audio_seqlens=audio_feature_lengths,
+            audio_token_id=getattr(self.hf_config, "audio_token_id", None),
+            audio_start_token_id=self.audio_start_token_id,
+            position_id_per_seconds=getattr(
+                self.hf_config, "position_id_per_seconds", None
+            ),
         )
         mrope_positions = mrope_positions.squeeze(1)
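For qwen3_omni_moe, the audio_seqlens handed to get_rope_index is simply the row-wise sum of the processor's feature_attention_mask, as in this toy sketch (made-up mask, not commit code):

import torch

feature_attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)  # tensor([4])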
...
...
@@ -293,6 +326,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
             "im_end_id": self.IM_END_TOKEN_ID,
             "im_token_id": self.mm_tokens.image_token_id,
             "video_token_id": self.mm_tokens.video_token_id,
+            "audio_token_id": self.mm_tokens.audio_token_id,
             "mrope_positions": mrope_positions,
             "mrope_position_delta": mrope_position_delta,
         }
test/srt/test_vision_openai_server_a.py
View file @
86b04d25
...
...
@@ -355,9 +355,10 @@ class TestPhi4MMServer(ImageOpenAITestMixin, AudioOpenAITestMixin):
 if __name__ == "__main__":
     del (
-        TestOpenAIOmniServerBase,
+        TestOpenAIMLLMServerBase,
         ImageOpenAITestMixin,
         VideoOpenAITestMixin,
         AudioOpenAITestMixin,
+        OmniOpenAITestMixin,
     )
     unittest.main()
test/srt/test_vision_openai_server_b.py
View file @
86b04d25
...
...
@@ -241,11 +241,35 @@ class TestGLM41VServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
         cls.base_url += "/v1"


+class TestQwen3OmniServer(OmniOpenAITestMixin):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                # workaround to fit into H100
+                "--trust-remote-code",
+                "--mem-fraction-static",
+                "0.90",
+                "--disable-cuda-graph",
+                "--disable-fast-image-processor",
+                "--grammar-backend",
+                "none",
+            ],
+        )
+        cls.base_url += "/v1"
+
+
 if __name__ == "__main__":
     del (
-        TestOpenAIOmniServerBase,
+        TestOpenAIMLLMServerBase,
         ImageOpenAITestMixin,
         VideoOpenAITestMixin,
         AudioOpenAITestMixin,
+        OmniOpenAITestMixin,
     )
     unittest.main()
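Once a server like the one launched in TestQwen3OmniServer is up, mixed image/audio requests go through the standard OpenAI-compatible chat endpoint. A rough usage sketch — the base_url, API key, and media URLs below are placeholders:

import openai

client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")
response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
                {"type": "audio_url", "audio_url": {"url": "https://example.com/clip.mp3"}},
                {"type": "text", "text": "Describe the image, then transcribe the audio."},
            ],
        }
    ],
    temperature=0,
)
print(response.choices[0].message.content)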
test/srt/test_vision_openai_server_common.py
View file @
86b04d25
 import base64
 import io
 import os
+from concurrent.futures import ThreadPoolExecutor

 import numpy as np
 import openai
...
...
@@ -22,7 +23,7 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test
 AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"


-class TestOpenAIOmniServerBase(CustomTestCase):
+class TestOpenAIMLLMServerBase(CustomTestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = ""
...
...
@@ -58,7 +59,20 @@ class TestOpenAIOmniServerBase(CustomTestCase):
         return file_path


-class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
+class AudioOpenAITestMixin(TestOpenAIMLLMServerBase):
+    def verify_speech_recognition_response(self, text):
+        check_list = [
+            "thank you",
+            "it's a privilege to be here",
+            "leader",
+            "science",
+            "art",
+        ]
+        for check_word in check_list:
+            assert (
+                check_word in text.lower()
+            ), f"audio_response: |{text}| should contain |{check_word}|"
+
     def prepare_audio_messages(self, prompt, audio_file_name):
         messages = [
             {
...
...
@@ -116,17 +130,7 @@ class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
             "Listen to this audio and write down the audio transcription in English.",
             category="speech",
         )
-        check_list = [
-            "thank you",
-            "it's a privilege to be here",
-            "leader",
-            "science",
-            "art",
-        ]
-        for check_word in check_list:
-            assert (
-                check_word in audio_response
-            ), f"audio_response: |{audio_response}| should contain |{check_word}|"
+        self.verify_speech_recognition_response(audio_response)

     def test_audio_ambient_completion(self):
         # bird song
...
...
@@ -138,26 +142,39 @@ class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
         assert "bird" in audio_response


-class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
-    def test_single_image_chat_completion(self):
+class ImageOpenAITestMixin(TestOpenAIMLLMServerBase):
+    def run_decode_with_image(self, image_id):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)

+        content = []
+        if image_id == 0:
+            content.append(
+                {
+                    "type": "image_url",
+                    "image_url": {"url": IMAGE_MAN_IRONING_URL},
+                }
+            )
+        elif image_id == 1:
+            content.append(
+                {
+                    "type": "image_url",
+                    "image_url": {"url": IMAGE_SGL_LOGO_URL},
+                }
+            )
+        else:
+            pass
+
+        content.append(
+            {
+                "type": "text",
+                "text": "Describe this image in a sentence.",
+            }
+        )
+
         response = client.chat.completions.create(
             model="default",
             messages=[
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "image_url",
-                            "image_url": {"url": IMAGE_MAN_IRONING_URL},
-                        },
-                        {
-                            "type": "text",
-                            "text": "Describe this image in a sentence.",
-                        },
-                    ],
-                },
+                {"role": "user", "content": content},
             ],
             temperature=0,
             **(self.get_vision_request_kwargs()),
...
...
@@ -166,6 +183,17 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
         assert response.choices[0].message.role == "assistant"
         text = response.choices[0].message.content
         assert isinstance(text, str)

+    def test_mixed_batch(self):
+        image_ids = [0, 1, 2] * 4
+        with ThreadPoolExecutor(4) as executor:
+            list(executor.map(self.run_decode_with_image, image_ids))
+
+    def verify_single_image_response(self, response):
+        assert response.choices[0].message.role == "assistant"
+        text = response.choices[0].message.content
+        assert isinstance(text, str)
+
         # `driver` is for gemma-3-it
         assert (
             "man" in text or "person" or "driver" in text
...
...
@@ -179,19 +207,44 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
         ), f"text: {text}, should contain cab, taxi, SUV, vehicle or car"
         # MiniCPMO fails to recognize `iron`, but `hanging`
         assert (
-            "iron" in text
-            or "hang" in text
-            or "cloth" in text
-            or "coat" in text
-            or "holding" in text
-            or "outfit" in text
-        ), f"text: {text}, should contain iron, hang, cloth, coat or holding or outfit"
+            "iron" in text or "hang" in text or "cloth" in text or "holding" in text
+        ), f"text: {text}, should contain iron, hang, cloth or holding"

         assert response.id
         assert response.created
         assert response.usage.prompt_tokens > 0
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

+    def test_single_image_chat_completion(self):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+        response = client.chat.completions.create(
+            model="default",
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image_url",
+                            "image_url": {"url": IMAGE_MAN_IRONING_URL},
+                        },
+                        {
+                            "type": "text",
+                            "text": "Describe this image in a sentence.",
+                        },
+                    ],
+                },
+            ],
+            temperature=0,
+            **(self.get_vision_request_kwargs()),
+        )
+
+        print("-" * 30)
+        print(f"Single image response:\n{response.choices[0].message.content}")
+        print("-" * 30)
+
+        self.verify_single_image_response(response)
+
     def test_multi_turn_chat_completion(self):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
...
...
@@ -264,8 +317,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
                     },
                     {
                         "type": "text",
-                        "text": "I have two very different images. They are not related at all. "
-                        "Please describe the first image in one sentence, and then describe the second image in another sentence.",
+                        "text": "I have two very different images. Please describe them.",
                     },
                 ],
             },
...
...
@@ -296,64 +348,6 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

-    def _test_mixed_image_audio_chat_completion(self):
-        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
-
-        response = client.chat.completions.create(
-            model="default",
-            messages=[
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "image_url",
-                            "image_url": {"url": IMAGE_MAN_IRONING_URL},
-                        },
-                        {
-                            "type": "audio_url",
-                            "audio_url": {"url": AUDIO_TRUMP_SPEECH_URL},
-                        },
-                        {
-                            "type": "text",
-                            "text": "Please describe the image in one sentence, and then write down the audio transcription in English.",
-                        },
-                    ],
-                },
-            ],
-            temperature=0,
-            **(self.get_vision_request_kwargs()),
-        )
-
-        assert response.choices[0].message.role == "assistant"
-        text = response.choices[0].message.content
-        assert isinstance(text, str)
-
-        print("-" * 30)
-        print(f"Mixed image & audio response:\n{text}")
-        print("-" * 30)
-
-        assert (
-            "man" in text
-            or "cab" in text
-            or "SUV" in text
-            or "taxi" in text
-            or "car" in text
-        ), f"text: {text}, should contain man, cab, SUV, taxi or car"
-
-        check_list = [
-            "thank you",
-            "it's a privilege to be here",
-            "leader",
-            "science",
-            "art",
-        ]
-        for check_word in check_list:
-            assert (
-                check_word in text
-            ), f"text: |{text}| should contain |{check_word}|"
-
-        assert response.id
-        assert response.created
-        assert response.usage.prompt_tokens > 0
-        assert response.usage.completion_tokens > 0
-        assert response.usage.total_tokens > 0

     def prepare_video_images_messages(self, video_path):
         # the memory consumed by the Vision Attention varies a lot, e.g. blocked qkv vs full-sequence sdpa
         # the size of the video embeds differs from the `modality` argument when preprocessed
...
...
@@ -461,7 +455,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
         self.assertGreater(len(video_response), 0)


-class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
+class VideoOpenAITestMixin(TestOpenAIMLLMServerBase):
     def prepare_video_messages(self, video_path):
         messages = [
             {
...
...
@@ -526,3 +520,45 @@ class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
         ), f"video_response: {video_response}, should contain 'black' or 'dark'"
         self.assertIsNotNone(video_response)
         self.assertGreater(len(video_response), 0)
+
+
+class OmniOpenAITestMixin(ImageOpenAITestMixin, VideoOpenAITestMixin, AudioOpenAITestMixin):
+    def test_mixed_modality_chat_completion(self):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": IMAGE_MAN_IRONING_URL},
+                    },
+                    {
+                        "type": "audio_url",
+                        "audio_url": {"url": AUDIO_TRUMP_SPEECH_URL},
+                    },
+                    {
+                        "type": "text",
+                        "text": "I have an image and audio, which are not related at all. Please: 1. Describe the image in a sentence, 2. Repeat the exact words from the audio I provided. Be exact",
+                    },
+                ],
+            },
+        ]
+
+        response = client.chat.completions.create(
+            model="default",
+            messages=messages,
+            temperature=0,
+            max_tokens=128,
+            stream=False,
+        )
+
+        text = response.choices[0].message.content
+
+        print("-" * 30)
+        print(f"Mixed modality response:\n{text}")
+        print("-" * 30)
+
+        self.verify_single_image_response(response=response)
+        self.verify_speech_recognition_response(text=text)