Unverified commit 86b04d25 authored by Mick, committed by GitHub

model: qwen3-omni (thinker-only) (#10911)


Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
parent 85ebeecf
@@ -853,6 +853,7 @@ multimodal_model_archs = [
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"Qwen3OmniMoeForConditionalGeneration",
"KimiVLForConditionalGeneration",
"InternVLChatModel",
"InternS1ForConditionalGeneration",
......
from transformers import PretrainedConfig
from transformers.configuration_utils import layer_type_validation
from transformers.modeling_rope_utils import rope_config_validation
from sglang.utils import logger
class Qwen3OmniMoeAudioEncoderConfig(PretrainedConfig):
model_type = "qwen3_omni_moe_audio_encoder"
def __init__(
self,
num_mel_bins=128,
encoder_layers=32,
encoder_attention_heads=20,
encoder_ffn_dim=5120,
d_model=1280,
dropout=0,
attention_dropout=0,
activation_function="gelu",
activation_dropout=0,
scale_embedding=False,
initializer_range=0.02,
max_source_positions=1500,
n_window=100,
output_dim=3584,
n_window_infer=400,
conv_chunksize=500,
downsample_hidden_size=480,
**kwargs,
):
super().__init__(**kwargs)
self.num_mel_bins = num_mel_bins
self.d_model = d_model
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.encoder_ffn_dim = encoder_ffn_dim
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_function = activation_function
self.activation_dropout = activation_dropout
self.num_hidden_layers = encoder_layers
self.initializer_range = initializer_range
self.scale_embedding = (
scale_embedding # scale factor will be sqrt(d_model) if True
)
self.max_source_positions = max_source_positions
self.n_window = n_window
self.output_dim = output_dim
self.n_window_infer = n_window_infer
self.conv_chunksize = conv_chunksize
self.downsample_hidden_size = downsample_hidden_size
class Qwen3OmniMoeVisionEncoderConfig(PretrainedConfig):
model_type = "qwen3_omni_moe_vision_encoder"
base_config_key = "vision_config"
def __init__(
self,
depth=27,
hidden_size=1152,
hidden_act="gelu_pytorch_tanh",
intermediate_size=4304,
num_heads=16,
in_channels=3,
patch_size=16,
spatial_merge_size=2,
temporal_patch_size=2,
out_hidden_size=3584,
num_position_embeddings=2304,
deepstack_visual_indexes=[8, 16, 24],
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.out_hidden_size = out_hidden_size
self.num_position_embeddings = num_position_embeddings
self.initializer_range = initializer_range
self.deepstack_visual_indexes = deepstack_visual_indexes
class Qwen3OmniMoeTextConfig(PretrainedConfig):
model_type = "qwen3_omni_moe_text"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen3OmniMoeText`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.experts.*.gate_proj": "colwise",
"layers.*.mlp.experts.*.up_proj": "colwise",
"layers.*.mlp.experts.*.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=3584,
hidden_size=2048,
intermediate_size=18944,
num_hidden_layers=28,
num_attention_heads=28,
num_key_value_heads=4,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=1000000.0,
rope_scaling=None,
attention_bias=False,
sliding_window=None,
attention_dropout=0,
decoder_sparse_step=1,
moe_intermediate_size=768,
num_experts_per_tok=8,
num_experts=128,
norm_topk_prob=True,
output_router_logits=False,
router_aux_loss_coef=0.001,
mlp_only_layers=None,
**kwargs,
):
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
# MoE arguments
self.decoder_sparse_step = decoder_sparse_step
self.moe_intermediate_size = moe_intermediate_size
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.norm_topk_prob = norm_topk_prob
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
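# Editorial usage sketch (not part of this commit): the back-compat shim above mirrors a
# legacy rope_scaling "type" key into "rope_type" before rope_config_validation runs.
# The scaling values below are illustrative only.
_example_text_cfg = Qwen3OmniMoeTextConfig(rope_scaling={"type": "linear", "factor": 2.0})
assert _example_text_cfg.rope_scaling["rope_type"] == "linear"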
class Qwen3OmniMoeThinkerConfig(PretrainedConfig):
model_type = "qwen3_omni_moe_thinker"
attribute_map = {
"image_token_id": "image_token_index",
"video_token_id": "video_token_index",
"audio_token_id": "audio_token_index",
}
sub_configs = {
"audio_config": Qwen3OmniMoeAudioEncoderConfig,
"vision_config": Qwen3OmniMoeVisionEncoderConfig,
"text_config": Qwen3OmniMoeTextConfig,
}
def __init__(
self,
audio_config=None,
vision_config=None,
text_config=None,
audio_token_id=151646,
image_token_id=151655,
video_token_id=151656,
position_id_per_seconds=25,
audio_start_token_id=151647,
user_token_id=872,
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.user_token_id = user_token_id
self.position_id_per_seconds = position_id_per_seconds
self.audio_start_token_id = audio_start_token_id
self.initializer_range = initializer_range
if isinstance(vision_config, dict):
vision_config = Qwen3OmniMoeVisionEncoderConfig(**vision_config)
elif vision_config is None:
vision_config = Qwen3OmniMoeVisionEncoderConfig()
self.vision_config = vision_config
if isinstance(audio_config, dict):
audio_config = Qwen3OmniMoeAudioEncoderConfig(**audio_config)
elif audio_config is None:
audio_config = Qwen3OmniMoeAudioEncoderConfig()
self.audio_config = audio_config
if isinstance(text_config, dict):
text_config = Qwen3OmniMoeTextConfig(**text_config)
elif text_config is None:
text_config = Qwen3OmniMoeTextConfig()
self.text_config = text_config
self.audio_token_id = audio_token_id
self.image_token_id = image_token_id
self.video_token_id = video_token_id
class Qwen3OmniMoeTalkerCodePredictorConfig(PretrainedConfig):
model_type = "qwen3_omni_moe_talker_code_predictor"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen3OmniMoeTalkerCodePredictor`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=2048,
hidden_size=1024,
intermediate_size=3072,
num_hidden_layers=5,
num_attention_heads=16,
num_key_value_heads=8,
head_dim=128,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=0.000001,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000,
rope_scaling=None,
attention_bias=False,
sliding_window=None,
layer_types=None,
attention_dropout=0,
num_code_groups=32,
**kwargs,
):
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = [
(
"sliding_attention"
if self.sliding_window is not None and i >= self.max_window_layers
else "full_attention"
)
for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types, self.num_hidden_layers)
self.num_code_groups = num_code_groups
class Qwen3OmniMoeTalkerTextConfig(PretrainedConfig):
model_type = "qwen3_omni_moe_talker_text"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen3OmniMoeTalkerText`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.experts.*.gate_proj": "colwise",
"layers.*.mlp.experts.*.up_proj": "colwise",
"layers.*.mlp.experts.*.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=3072,
hidden_size=1024,
intermediate_size=2048,
num_hidden_layers=20,
num_attention_heads=16,
num_key_value_heads=2,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=0.000001,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000,
rope_scaling=None,
attention_bias=False,
sliding_window=None,
attention_dropout=0,
decoder_sparse_step=1,
moe_intermediate_size=384,
num_experts_per_tok=8,
num_experts=128,
norm_topk_prob=False,
output_router_logits=False,
router_aux_loss_coef=0.001,
mlp_only_layers=None,
**kwargs,
):
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
# MoE arguments
self.decoder_sparse_step = decoder_sparse_step
self.moe_intermediate_size = moe_intermediate_size
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.norm_topk_prob = norm_topk_prob
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
class Qwen3OmniMoeTalkerConfig(PretrainedConfig):
sub_configs = {
"code_predictor_config": Qwen3OmniMoeTalkerCodePredictorConfig,
"text_config": Qwen3OmniMoeTalkerTextConfig,
}
def __init__(
self,
code_predictor_config=None,
text_config=None,
num_code_groups=32,
thinker_hidden_size=2048,
codec_eos_token_id=4198,
accept_hidden_layer=18,
codec_nothink_id=4203,
codec_think_bos_id=4204,
codec_think_eos_id=4205,
codec_pad_id=4196,
codec_bos_id=4197,
audio_token_id=151646,
image_token_id=151655,
video_token_id=151656,
vision_start_token_id=151652,
position_id_per_seconds=25,
audio_start_token_id=151669,
speaker_id=None,
**kwargs,
):
super().__init__(**kwargs)
if code_predictor_config is None:
code_predictor_config = {}
self.code_predictor_config = Qwen3OmniMoeTalkerCodePredictorConfig()
logger.info(
"code_predictor_config is None. Initializing code_predictor_config model with default values"
)
elif isinstance(code_predictor_config, Qwen3OmniMoeTalkerCodePredictorConfig):
self.code_predictor_config = code_predictor_config
else:
self.code_predictor_config = Qwen3OmniMoeTalkerCodePredictorConfig(
**code_predictor_config
)
if text_config is None:
text_config = {}
self.text_config = Qwen3OmniMoeTalkerTextConfig()
logger.info(
"talker text_config is None. Initializing talker text model with default values"
)
elif isinstance(text_config, Qwen3OmniMoeTalkerTextConfig):
self.text_config = text_config
else:
self.text_config = Qwen3OmniMoeTalkerTextConfig(**text_config)
self.num_code_groups = num_code_groups
self.thinker_hidden_size = thinker_hidden_size
self.codec_eos_token_id = codec_eos_token_id
self.accept_hidden_layer = accept_hidden_layer
self.codec_nothink_id = codec_nothink_id
self.codec_think_bos_id = codec_think_bos_id
self.codec_think_eos_id = codec_think_eos_id
self.codec_pad_id = codec_pad_id
self.codec_bos_id = codec_bos_id
self.audio_token_id = audio_token_id
self.image_token_id = image_token_id
self.video_token_id = video_token_id
self.position_id_per_seconds = position_id_per_seconds
self.audio_start_token_id = audio_start_token_id
self.vision_start_token_id = vision_start_token_id
self.speaker_id = speaker_id
class Qwen3OmniMoeCode2WavConfig(PretrainedConfig):
def __init__(
self,
codebook_size=2048,
hidden_size=1024,
max_position_embeddings=8000,
rope_theta=10000,
num_attention_heads=16,
num_key_value_heads=16,
attention_bias=False,
sliding_window=72,
intermediate_size=3072,
hidden_act="silu",
layer_scale_initial_scale=0.01,
rms_norm_eps=1e-5,
num_hidden_layers=8,
num_quantizers=16,
upsample_rates=(8, 5, 4, 3),
upsampling_ratios=(2, 2),
decoder_dim=1536,
attention_dropout=0.0,
**kwargs,
):
super().__init__(**kwargs)
self.codebook_size = codebook_size
self.hidden_size = hidden_size
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.attention_bias = attention_bias
self.sliding_window = sliding_window
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.layer_scale_initial_scale = layer_scale_initial_scale
self.rms_norm_eps = rms_norm_eps
self.num_hidden_layers = num_hidden_layers
self.num_quantizers = num_quantizers
self.upsample_rates = upsample_rates
self.upsampling_ratios = upsampling_ratios
self.decoder_dim = decoder_dim
self.attention_dropout = attention_dropout
@property
def layer_types(self):
"""
All layers in code2wav should use sliding attention
"""
return ["sliding_attention"] * self.num_hidden_layers
class Qwen3OmniMoeConfig(PretrainedConfig):
model_type = "qwen3_omni_moe"
sub_configs = {
"thinker_config": Qwen3OmniMoeThinkerConfig,
"talker_config": Qwen3OmniMoeTalkerConfig,
"code2wav_config": Qwen3OmniMoeCode2WavConfig,
}
def __init__(
self,
thinker_config=None,
talker_config=None,
code2wav_config=None,
enable_audio_output=True,
im_start_token_id=151644,
im_end_token_id=151645,
tts_pad_token_id=151671,
tts_bos_token_id=151672,
tts_eos_token_id=151673,
system_token_id=8948,
user_token_id=872,
assistant_token_id=77091,
**kwargs,
):
super().__init__(**kwargs)
if thinker_config is None:
thinker_config = {}
logger.info(
"thinker_config is None. Initializing thinker model with default values"
)
if talker_config is None:
talker_config = {}
logger.info(
"talker_config is None. Initializing talker model with default values"
)
if code2wav_config is None:
code2wav_config = {}
logger.info(
"code2wav_config is None. Initializing code2wav model with default values"
)
self.thinker_config = Qwen3OmniMoeThinkerConfig(**thinker_config)
self.talker_config = Qwen3OmniMoeTalkerConfig(**talker_config)
self.code2wav_config = Qwen3OmniMoeCode2WavConfig(**code2wav_config)
self.enable_audio_output = enable_audio_output
self.im_start_token_id = im_start_token_id
self.im_end_token_id = im_end_token_id
self.tts_pad_token_id = tts_pad_token_id
self.tts_bos_token_id = tts_bos_token_id
self.tts_eos_token_id = tts_eos_token_id
self.system_token_id = system_token_id
self.user_token_id = user_token_id
self.assistant_token_id = assistant_token_id
def get_text_config(self, decoder=False) -> "PretrainedConfig":
"""
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
itself. On specific composite models, it is under a set of valid names.
Args:
decoder (`Optional[bool]`, *optional*, defaults to `False`):
If set to `True`, then only search for decoder config names.
"""
# Overridden for deeply nested configs like Qwen2-Omni. We don't have any omni model
# except for Qwen yet. This has to be generalized if more deeply nested configs are
# added. NOTE: this method is currently used only by vLLM.
return self.thinker_config.get_text_config()
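# Editorial end-to-end sketch of the nested config above (not part of this commit). It
# only exercises defaults defined in this file; get_text_config() is expected to resolve
# to the thinker's Qwen3OmniMoeTextConfig.
_example_cfg = Qwen3OmniMoeConfig(thinker_config={"text_config": {"num_hidden_layers": 4}})
assert isinstance(_example_cfg.thinker_config.text_config, Qwen3OmniMoeTextConfig)
assert _example_cfg.thinker_config.text_config.num_hidden_layers == 4
print(type(_example_cfg.get_text_config()).__name__)  # expected: Qwen3OmniMoeTextConfig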
from typing import Optional, Union
from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
@@ -576,11 +574,3 @@ class Qwen3VLMoeConfig(PretrainedConfig):
self.vision_start_token_id = vision_start_token_id
self.vision_end_token_id = vision_end_token_id
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
__all__ = [
"Qwen3VLMoeConfig",
"Qwen3VLMoeVisionConfig",
"Qwen3VLConfig",
"Qwen3VLVisionConfig",
]
@@ -1156,6 +1156,20 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
if model_type == "qwen3_omni_moe":
# For qwen3-omni
return MRotaryEmbedding.get_rope_index_qwen3_omni(
spatial_merge_size,
image_token_id,
video_token_id,
vision_start_token_id,
tokens_per_second,
input_ids,
image_grid_thw,
video_grid_thw,
second_per_grid_ts,
**kwargs,
)
if (
model_type.startswith("qwen3_vl") or model_type.startswith("qwen3_vl_moe")
) and video_grid_thw is not None:
@@ -1163,6 +1177,7 @@ class MRotaryEmbedding(RotaryEmbedding):
video_grid_thw, video_grid_thw[:, 0], dim=0
)
video_grid_thw[:, 0] = 1
mrope_position_deltas = []
if input_ids is not None and (
image_grid_thw is not None or video_grid_thw is not None
@@ -1248,7 +1263,11 @@ class MRotaryEmbedding(RotaryEmbedding):
time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten()
elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
elif model_type in (
"qwen2_vl",
"qwen3_vl",
"qwen3_vl_moe",
):
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
@@ -1256,7 +1275,7 @@ class MRotaryEmbedding(RotaryEmbedding):
.flatten()
)
else:
raise RuntimeError("Unimplemented")
raise RuntimeError(f"Unimplemented model type: {model_type}")
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
@@ -1306,6 +1325,304 @@ class MRotaryEmbedding(RotaryEmbedding):
mrope_position_deltas = max_position_ids + 1 - s
return position_ids, mrope_position_deltas
@staticmethod
def get_rope_index_qwen3_omni(
spatial_merge_size: int,
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
tokens_per_second: Optional[int] = None,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# For qwen3-omni
audio_token_id = kwargs["audio_token_id"]
audio_start_token_id = kwargs["audio_start_token_id"]
position_id_per_seconds = kwargs["position_id_per_seconds"]
use_audio_in_video = kwargs.get("use_audio_in_video", False)
audio_seqlens = kwargs.get("audio_seqlens", None)
second_per_grids = second_per_grid_ts
mrope_position_deltas = []
if input_ids is not None and (
image_grid_thw is not None or video_grid_thw is not None
):
total_input_ids = input_ids
position_ids = torch.zeros(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=torch.float,
device=input_ids.device,
)
image_idx, video_idx, audio_idx = 0, 0, 0
for i, current_input_ids in enumerate(total_input_ids):
image_nums, video_nums, audio_nums = 0, 0, 0
vision_start_indices = torch.argwhere(
current_input_ids == vision_start_token_id
).squeeze(1)
if vision_start_indices.numel() > 0:
vision_tokens = current_input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (
(vision_tokens == audio_start_token_id).sum()
if use_audio_in_video
else (vision_tokens == video_token_id).sum()
)
audio_nums = torch.sum(current_input_ids == audio_start_token_id)
input_tokens = current_input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos, remain_audios = (
image_nums,
video_nums,
audio_nums,
)
multimodal_nums = (
image_nums + audio_nums
if use_audio_in_video
else image_nums + video_nums + audio_nums
)
for _ in range(multimodal_nums):
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
ed_vision_start = (
input_tokens.index(vision_start_token_id, st)
if (
(
image_token_id in input_tokens
or video_token_id in input_tokens
)
and (remain_videos > 0 or remain_images > 0)
)
else len(input_tokens) + 1
)
ed_audio_start = (
input_tokens.index(audio_start_token_id, st)
if (audio_token_id in input_tokens and remain_audios > 0)
else len(input_tokens) + 1
)
min_ed = min(ed_vision_start, ed_audio_start)
text_len = min_ed - st
if text_len != 0:
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
st_idx += text_len
# Audio in Video
if (
min_ed == ed_vision_start
and ed_vision_start + 1 == ed_audio_start
):
bos_len, eos_len = 2, 2
else:
bos_len, eos_len = 1, 1
llm_pos_ids_list.append(
torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx
)
st_idx += bos_len
# Audio Only
if min_ed == ed_audio_start:
audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
audio_seqlens[audio_idx]
)
llm_pos_ids = (
torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
)
llm_pos_ids_list.append(llm_pos_ids)
st += int(text_len + bos_len + audio_len + eos_len)
audio_idx += 1
remain_audios -= 1
# Image Only
elif (
min_ed == ed_vision_start
and current_input_ids[ed_vision_start + 1] == image_token_id
):
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = (
torch.arange(grid_t) * 1 * position_id_per_seconds
).float()
llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
st_idx,
image_idx,
spatial_merge_size,
t_index,
grid_hs,
grid_ws,
input_ids.device,
)
image_len = image_grid_thw[image_idx].prod() // (
spatial_merge_size**2
)
llm_pos_ids_list.append(llm_pos_ids)
st += int(text_len + bos_len + image_len + eos_len)
image_idx += 1
remain_images -= 1
# Video Only
elif (
min_ed == ed_vision_start
and current_input_ids[ed_vision_start + 1] == video_token_id
):
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* second_per_grids[video_idx].cpu().float()
* position_id_per_seconds
).float()
llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
st_idx,
video_idx,
spatial_merge_size,
t_index,
grid_hs,
grid_ws,
input_ids.device,
)
video_len = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2
)
llm_pos_ids_list.append(llm_pos_ids)
st += int(text_len + bos_len + video_len + eos_len)
video_idx += 1
remain_videos -= 1
# Audio in Video
elif (
min_ed == ed_vision_start
and ed_vision_start + 1 == ed_audio_start
):
audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
audio_seqlens[audio_idx]
)
audio_llm_pos_ids = (
torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
)
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* second_per_grids[video_idx].cpu().float()
* position_id_per_seconds
).float()
video_llm_pos_ids = (
MRotaryEmbedding._get_llm_pos_ids_for_vision(
st_idx,
video_idx,
spatial_merge_size,
t_index,
grid_hs,
grid_ws,
input_ids.device,
)
)
video_data_index, audio_data_index = 0, 0
while (
video_data_index < video_llm_pos_ids.shape[-1]
and audio_data_index < audio_llm_pos_ids.shape[-1]
):
if (
video_llm_pos_ids[0][video_data_index]
<= audio_llm_pos_ids[0][audio_data_index]
):
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_data_index + 1
]
)
video_data_index += 1
else:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_data_index + 1
]
)
audio_data_index += 1
if video_data_index < video_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_llm_pos_ids.shape[-1]
]
)
if audio_data_index < audio_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_llm_pos_ids.shape[-1]
]
)
video_len = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2
)
st += int(text_len + bos_len + audio_len + video_len + eos_len)
audio_idx += 1
video_idx += 1
remain_videos -= 1
remain_audios -= 1
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
llm_pos_ids_list.append(
torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx
)
if st < len(input_tokens):
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
llm_positions = torch.cat(
[item.float() for item in llm_pos_ids_list], dim=1
).reshape(3, -1)
position_ids[..., i, :] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(
llm_positions.max() + 1 - len(current_input_ids)
)
mrope_position_deltas = torch.tensor(
mrope_position_deltas, device=input_ids.device
).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
s = input_ids.shape[1]
position_ids = torch.arange(s)
position_ids = (
position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
-1, keepdim=True
)[0]
mrope_position_deltas = max_position_ids + 1 - s
return position_ids, mrope_position_deltas
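# Editorial illustration (not part of this commit) of the audio-in-video branch above:
# two (3, N) blocks of position ids are interleaved by their temporal (first) row, so
# audio and video tokens that occur at the same time receive adjacent positions.
# Tensor values below are made up for the sketch.
def _example_interleave_by_time(video_pos, audio_pos):
    import torch
    chunks, v, a = [], 0, 0
    while v < video_pos.shape[-1] and a < audio_pos.shape[-1]:
        if video_pos[0, v] <= audio_pos[0, a]:
            chunks.append(video_pos[:, v : v + 1])
            v += 1
        else:
            chunks.append(audio_pos[:, a : a + 1])
            a += 1
    chunks.append(video_pos[:, v:])  # whichever stream has leftovers
    chunks.append(audio_pos[:, a:])
    return torch.cat(chunks, dim=1)
# e.g. video frames at t = 0 and t = 25 interleaved with four audio positions 0..3:
# _example_interleave_by_time(torch.tensor([[0, 25], [0, 0], [0, 0]]),
#                             torch.arange(4).view(1, -1).expand(3, -1))[0]
# -> tensor([ 0,  0,  1,  2,  3, 25])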
# Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
@staticmethod
def get_rope_index_glm4v(
@@ -1504,6 +1821,44 @@ class MRotaryEmbedding(RotaryEmbedding):
return position_ids, mrope_position_deltas
# For qwen3-omni
@staticmethod
def _get_feat_extract_output_lengths(input_lengths):
"""
Computes the output length of the convolutional layers and the output length of the audio encoder
"""
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return output_lengths
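# Editorial sanity check (not part of this commit) on the length formula above: each full
# 100-frame mel window contributes 13 encoder tokens, and the remainder goes through the
# same three (x - 1) // 2 + 1 halvings.
def _example_feat_len(n: int) -> int:
    leave = n % 100
    f = (leave - 1) // 2 + 1
    return ((f - 1) // 2 + 1 - 1) // 2 + 1 + (n // 100) * 13
assert _example_feat_len(100) == 13  # one full window
assert _example_feat_len(300) == 39  # three full windows
assert _example_feat_len(250) == 33  # 2 * 13 plus 7 tokens for the 50-frame remainder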
# For qwen3-omni
@staticmethod
def _get_llm_pos_ids_for_vision(
st_idx, vision_idx, spatial_merge_size, t_index, grid_hs, grid_ws, device
):
grid_h = grid_hs[vision_idx] // spatial_merge_size
grid_w = grid_ws[vision_idx] // spatial_merge_size
h_index = (
torch.arange(grid_h, device=device)
.view(1, -1, 1)
.expand(len(t_index), -1, grid_w)
.flatten()
)
w_index = (
torch.arange(grid_w, device=device)
.view(1, 1, -1)
.expand(len(t_index), grid_h, -1)
.flatten()
)
t_index = t_index.view(-1, 1).expand(-1, grid_h * grid_w).flatten()
llm_pos_ids = torch.stack([t_index, h_index, w_index], dim=0) + st_idx
return llm_pos_ids
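# Editorial illustration (not part of this commit): for a single frame with a 4x4 patch
# grid and spatial_merge_size=2, the merged grid is 2x2, so the helper above emits four
# (t, h, w) position triples offset by st_idx. Inputs below are made up for the sketch.
_example_pos = MRotaryEmbedding._get_llm_pos_ids_for_vision(
    st_idx=10,
    vision_idx=0,
    spatial_merge_size=2,
    t_index=torch.tensor([0.0]),
    grid_hs=torch.tensor([4]),
    grid_ws=torch.tensor([4]),
    device="cpu",
)
# expected rows (values promoted to float):
# _example_pos[0] (temporal) == [10, 10, 10, 10]
# _example_pos[1] (height)   == [10, 10, 11, 11]
# _example_pos[2] (width)    == [10, 11, 10, 11]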
class DualChunkRotaryEmbedding(CustomOp):
"""Rotary positional embedding for Dual Chunk Attention."""
......
@@ -280,7 +280,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
input_ids_tensor[input_ids_tensor == token_id] = pad_value
ret_input_ids = input_ids_tensor.tolist()
return ret_input_ids
@@ -507,7 +506,7 @@ def embed_mm_inputs(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
use_deepstack: bool = False,
use_deepstack: Dict[Modality, bool] = {},
) -> Optional[torch.Tensor]:
"""
Embed multimodal inputs and integrate them with text token embeddings.
@@ -533,7 +532,9 @@ def embed_mm_inputs(
for mm_inputs in mm_inputs_list:
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
embeddings, masks, deepstack_embeddings = [], [], []
# deepstack_embeddings: per-modality
modalities, embeddings, masks, deepstack_embeddings = [], [], [], []
# 2. Get multimodal embedding separately
# Try get mm embedding if any
for modality in Modality.all():
@@ -549,7 +550,8 @@ def embed_mm_inputs(
# "image", "video", etc
modality_id = modality.name.lower()
embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
if len(items) != 0 and embedder is not None:
if len(items) != 0:
assert embedder is not None, f"no embedding method found for {modality}"
placeholder_tensor = torch.as_tensor(
[item.pad_value for item in items],
device=input_ids.device,
@@ -580,11 +582,12 @@ def embed_mm_inputs(
items_offset_list=items_offsets,
)
if use_deepstack and embedding is not None:
if use_deepstack.get(modality, None) and embedding is not None:
embedding, deepstack_embedding = (
multimodal_model.separate_deepstack_embeds(embedding)
)
deepstack_embeddings += [deepstack_embedding]
modalities += [modality]
embeddings += [embedding]
masks += [mask]
@@ -597,17 +600,14 @@ def embed_mm_inputs(
input_ids.clamp_(min=0, max=vocab_size - 1)
inputs_embeds = input_embedding(input_ids)
# 4. scatter embeddings into input embedding
# deepstack embedding
if use_deepstack:
num_deepstack_embeddings = (
num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes)
len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
)
deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
inputs_embeds.shape[-1] * num_deepstack_embeddings,
)
# a zero-filled embedding, with the same length of inputs_embeds, but different hidden_size
input_deepstack_embeds = torch.zeros(
deepstack_embedding_shape,
device=inputs_embeds.device,
@@ -616,14 +616,16 @@ def embed_mm_inputs(
other_info["input_deepstack_embeds"] = input_deepstack_embeds
for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
# 4. scatter embeddings into input embedding
for i, modality, embedding, mask in zip(
range(len(embeddings)), modalities, embeddings, masks
):
if embedding is None or mask is None:
continue
# in-place update
indices = torch.where(mask.squeeze(dim=-1))[0]
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
if use_deepstack.get(modality, None):
if use_deepstack:
input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
inputs_embeds.device, inputs_embeds.dtype
)
@@ -640,7 +642,7 @@ def general_mm_embed_routine(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
use_deepstack: bool = False,
use_deepstack: Dict[Modality, bool] = {},
**kwargs,
) -> torch.Tensor:
"""
@@ -652,7 +654,7 @@ def general_mm_embed_routine(
language_model: Base language model to use
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
placeholder_tokens: Token IDs for multimodal placeholders
use_deepstack: Whether to use deepstack embeddings
use_deepstack: Whether to use deepstack embeddings for each modality, default False
**kwargs: Additional arguments passed to language model
Returns:
......
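# Editorial note with a small sketch (not part of this commit): use_deepstack changes from
# a single bool to a per-modality mapping, and embed_mm_inputs consults it with
# use_deepstack.get(modality, None), so unlisted modalities simply skip the deepstack path.
# Modality is the enum imported from sglang.srt.managers.schedule_batch.
_example_use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True}  # audio left out
assert _example_use_deepstack.get(Modality.IMAGE, None)
assert not _example_use_deepstack.get(Modality.AUDIO, None)
# A caller would pass this dict through as the use_deepstack= keyword of
# general_mm_embed_routine / embed_mm_inputs.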
@@ -587,9 +587,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
)
if self.mm_processor and obj.contains_mm_input():
if not isinstance(obj.image_data, list) and obj.image_data:
if obj.image_data is not None and not isinstance(obj.image_data, list):
obj.image_data = [obj.image_data]
if not isinstance(obj.audio_data, list) and obj.audio_data:
if obj.audio_data is not None and not isinstance(obj.audio_data, list):
obj.audio_data = [obj.audio_data]
mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
image_data=obj.image_data,
......
@@ -518,6 +518,7 @@ class Qwen2MoeModel(nn.Module):
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()
......
@@ -661,13 +661,14 @@ class Qwen3MoeModel(Qwen2MoeModel):
config: Qwen3MoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
decoder_layer_type=Qwen3MoeDecoderLayer,
) -> None:
alt_stream = torch.cuda.Stream() if _is_cuda else None
super().__init__(
config=config,
quant_config=quant_config,
prefix=prefix,
decoder_layer_type=Qwen3MoeDecoderLayer,
decoder_layer_type=decoder_layer_type,
alt_stream=alt_stream,
)
......
# Copyright 2025 Qwen Team
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput
from sglang.srt.configs.qwen3_omni import (
Qwen3OmniMoeAudioEncoderConfig,
Qwen3OmniMoeThinkerConfig,
Qwen3OmniMoeVisionEncoderConfig,
)
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import MultimodalDataItem
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3_vl import Qwen3VLMoeVisionModel
from sglang.srt.models.qwen3_vl_moe import (
Qwen3MoeLLMModel,
Qwen3VLMoeForConditionalGeneration,
load_fused_expert_weights,
)
from sglang.srt.utils import add_prefix, logger
class Qwen3OmniMoeAudioEncoderLayer(nn.Module):
def __init__(
self,
config: Qwen3OmniMoeAudioEncoderConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
embed_dim = config.d_model
self.embed_dim = config.d_model
self.self_attn = VisionAttention(
embed_dim=embed_dim,
num_heads=config.encoder_attention_heads,
projection_size=embed_dim,
use_qkv_parallel=True,
rotary_embed="normal",
proj_bias=True,
qkv_backend="fa3",
softmax_in_single_precision=False,
flatten_batch=True,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
**kwargs,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
x=hidden_states,
cu_seqlens=cu_seqlens,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(
hidden_states, min=-clamp_value, max=clamp_value
)
outputs = (hidden_states,)
return outputs
class SinusoidsPositionEmbedding(nn.Module):
def __init__(self, length, channels, max_timescale=10000):
super().__init__()
if channels % 2 != 0:
raise ValueError("SinusoidsPositionEmbedding needs even channels input")
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(
-log_timescale_increment * torch.arange(channels // 2).float()
)
scaled_time = (
torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
)
self.register_buffer(
"positional_embedding",
torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
persistent=False,
)
def forward(self, seqlen: int):
return self.positional_embedding[:seqlen, :]
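# Editorial shape check (not part of this commit), using the audio-config defaults above
# (max_source_positions=1500, d_model=1280): the table holds a sin half and a cos half.
_example_pe = SinusoidsPositionEmbedding(length=1500, channels=1280)
_example_slice = _example_pe(200)
assert _example_slice.shape == (200, 1280)
assert torch.all(_example_slice[0, :640] == 0)  # sin(0) == 0 for every timescale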
def _get_feat_extract_output_lengths(input_lengths):
"""
Computes the output length of the convolutional layers and the output length of the audio encoder
"""
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return output_lengths
class Qwen3OmniMoeAudioEncoder(PreTrainedModel):
config: Qwen3OmniMoeAudioEncoderConfig
def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig):
super().__init__(config)
self.dropout = config.dropout
embed_dim = config.d_model
self.num_mel_bins = config.num_mel_bins
self.max_source_positions = config.max_source_positions
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.n_window = config.n_window
self.positional_embedding = SinusoidsPositionEmbedding(
self.max_source_positions, embed_dim
)
self.layers = nn.ModuleList(
[
Qwen3OmniMoeAudioEncoderLayer(config)
for _ in range(config.encoder_layers)
]
)
self.ln_post = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
self.conv2d2 = nn.Conv2d(
config.downsample_hidden_size,
config.downsample_hidden_size,
3,
2,
padding=1,
)
self.conv2d3 = nn.Conv2d(
config.downsample_hidden_size,
config.downsample_hidden_size,
3,
2,
padding=1,
)
self.conv_out = nn.Linear(
config.downsample_hidden_size
* ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2),
config.d_model,
bias=False,
)
self.proj1 = nn.Linear(config.d_model, config.d_model)
self.act = ACT2FN[config.activation_function]
self.proj2 = nn.Linear(config.d_model, config.output_dim)
self.n_window_infer = self.config.n_window_infer
self.conv_chunksize = self.config.conv_chunksize
def _freeze_parameters(self):
for param in self.parameters():
param.requires_grad = False
self._requires_grad = False
def get_input_embeddings(self) -> nn.Module:
return self.conv1
def set_input_embeddings(self, value: nn.Module):
self.conv1 = value
def forward(
self,
input_features,
feature_lens=None,
aftercnn_lens=None,
):
r"""
feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
mel length
aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
mel length after cnn
"""
aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
chunk_lengths = torch.tensor(
[self.n_window * 2] * chunk_num.sum(),
dtype=torch.long,
device=feature_lens.device,
)
tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
chunk_lengths[chunk_lengths == 0] = self.n_window * 2
chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
padded_feature = nn.utils.rnn.pad_sequence(
chunk_list, batch_first=True
).transpose(1, 2)
feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
[
torch.ones(length, dtype=torch.bool, device=padded_feature.device)
for length in feature_lens_after_cnn
],
batch_first=True,
)
padded_feature = padded_feature.unsqueeze(1)
# Split to chunk to avoid OOM during convolution
padded_embeds = []
for chunk in padded_feature.split(self.conv_chunksize, dim=0):
padded_embed = F.gelu(self.conv2d1(chunk))
padded_embed = F.gelu(self.conv2d2(padded_embed))
padded_embed = F.gelu(self.conv2d3(padded_embed))
padded_embeds.append(padded_embed)
padded_embed = torch.cat(padded_embeds, dim=0)
b, c, f, t = padded_embed.size()
padded_embed = self.conv_out(
padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)
)
positional_embedding = (
self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
.unsqueeze(0)
.to(padded_embed.dtype)
)
padded_embed = padded_embed + positional_embedding
hidden_states = padded_embed[padded_mask_after_cnn]
cu_chunk_lens = [0]
window_aftercnn = padded_mask_after_cnn.shape[-1] * (
self.n_window_infer // (self.n_window * 2)
)
for cnn_len in aftercnn_lens:
cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
remainder = cnn_len % window_aftercnn
if remainder != 0:
cu_chunk_lens += [remainder]
cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(
-1, dtype=torch.int32
)
for encoder_layer in self.layers:
layer_outputs = encoder_layer(
hidden_states,
cu_seqlens,
)
hidden_states = layer_outputs[0]
hidden_states = self.ln_post(hidden_states)
hidden_states = self.proj1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.proj2(hidden_states)
return BaseModelOutput(last_hidden_state=hidden_states)
# Ignore copy
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""
Computes the output length of the convolutional layers and the output length of the audio encoder
"""
input_lengths = (input_lengths - 1) // 2 + 1
output_lengths = (input_lengths - 2) // 2 + 1
return input_lengths, output_lengths
class Qwen3OmniMoeVisionPatchMerger(nn.Module):
def __init__(
self,
dim: int,
context_dim: int,
spatial_merge_size: int = 2,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_postshuffle_norm=False,
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
self.use_postshuffle_norm = use_postshuffle_norm
self.ln_q = RMSNorm(
self.hidden_size if use_postshuffle_norm else context_dim, eps=1e-6
)
self.mlp = nn.ModuleList(
[
ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=add_prefix("mlp.0", prefix),
),
nn.GELU(),
RowParallelLinear(
self.hidden_size,
dim,
bias=True,
quant_config=quant_config,
prefix=add_prefix("mlp.2", prefix),
),
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = (
x.view(-1, self.hidden_size)
if self.use_postshuffle_norm
else x.view(-1, x.shape[-1])
)
hidden = self.ln_q(x).view(-1, self.hidden_size)
for layer in self.mlp:
if isinstance(hidden, tuple):
hidden = hidden[0]
hidden = layer(hidden)
if isinstance(hidden, tuple):
hidden = hidden[0]
return hidden
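# Editorial dimension sketch (not part of this commit), using the vision-config defaults
# above: spatial_merge_size=2 concatenates 4 patches of hidden_size 1152 into 4608
# features, which the two-layer MLP then projects to out_hidden_size 3584.
_example_merge, _example_hidden, _example_out = 2, 1152, 3584
assert _example_hidden * _example_merge**2 == 4608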
class Qwen3OmniMoeVisionEncoder(Qwen3VLMoeVisionModel):
config: Qwen3OmniMoeVisionEncoderConfig
def __init__(
self,
config: Qwen3OmniMoeVisionEncoderConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = None,
**kwargs,
):
super().__init__(
vision_config=config,
quant_config=quant_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
)
self.merger = Qwen3OmniMoeVisionPatchMerger(
dim=config.out_hidden_size,
context_dim=config.hidden_size,
spatial_merge_size=config.spatial_merge_size,
quant_config=quant_config,
use_postshuffle_norm=False,
prefix=add_prefix("merger", prefix),
)
self.merger_list = nn.ModuleList(
[
Qwen3OmniMoeVisionPatchMerger(
dim=config.out_hidden_size,
context_dim=config.hidden_size,
spatial_merge_size=config.spatial_merge_size,
use_postshuffle_norm=True,
quant_config=quant_config,
prefix=add_prefix("merger_list", prefix),
)
for _ in range(len(config.deepstack_visual_indexes))
]
)
del self.deepstack_merger_list
@property
def deepstack_merger_list(self):
return self.merger_list
@property
def dtype(self) -> torch.dtype:
return self.patch_embed.proj.weight.dtype
@property
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device
class Qwen3OmniMoeThinkerForConditionalGeneration(Qwen3VLMoeForConditionalGeneration):
config: Qwen3OmniMoeThinkerConfig
def __init__(
self,
config: Qwen3OmniMoeThinkerConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(
config, quant_config, prefix, language_model_cls=Qwen3MoeLLMModel
)
self.audio_tower = Qwen3OmniMoeAudioEncoder(config.audio_config)
self.visual = Qwen3OmniMoeVisionEncoder(
config.vision_config,
quant_config=quant_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
prefix=add_prefix("visual", prefix),
)
self.pad_token_id = (
self.config.pad_token_id if self.config.pad_token_id is not None else -1
)
def get_audio_feature(self, items: List[MultimodalDataItem]):
feature_attention_mask = torch.cat(
[item.feature_attention_mask for item in items], dim=0
).type(torch.long)
input_features = (
torch.cat([item.feature for item in items])
.type(self.audio_tower.dtype)
.to(next(self.audio_tower.parameters()).device)
)
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
input_features = input_features.permute(0, 2, 1)[
feature_attention_mask.bool()
].permute(1, 0)
else:
audio_feature_lengths = None
feature_lens = (
audio_feature_lengths
if audio_feature_lengths is not None
else feature_attention_mask.sum(-1)
)
audio_outputs = self.audio_tower(
input_features,
feature_lens=feature_lens,
)
audio_features = audio_outputs.last_hidden_state
return audio_features
class Qwen3OmniMoeForConditionalGeneration(PreTrainedModel):
def __init__(
self,
config: Qwen3VLMoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(config)
self.config = config
self.thinker = Qwen3OmniMoeThinkerForConditionalGeneration(
config.thinker_config, quant_config=quant_config, prefix=prefix
)
self.enable_talker = False
self.pad_input_ids = self.thinker.pad_input_ids
self.forward = self.thinker.forward
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
("gate_up_proj", "up_proj", 1),
("gate_up_proj", "gate_proj", 0),
]
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
# Skip loading extra parameters for GPTQ/modelopt models.
ignore_suffixes = (
".bias",
"_bias",
".k_scale",
"_k_scale",
".v_scale",
"_v_scale",
".weight_scale",
"_weight_scale",
".input_scale",
"_input_scale",
)
is_fused_expert = False
fused_expert_params_mapping = [
("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
("experts.w2_weight", "experts.down_proj", 0, "w2"),
]
num_experts = self.config.num_experts
# Cache params_dict to avoid repeated expensive traversal of model parameters
if not hasattr(self, "_cached_params_dict"):
self._cached_params_dict = dict(self.named_parameters())
params_dict = self._cached_params_dict
for name, loaded_weight in weights:
name = name.replace(r"model.language_model.", r"model.")
if ("talker" in name or "code2wav" in name) and not self.enable_talker:
continue
name = name.replace(".self_attn.out_proj", ".self_attn.proj")
for param_name, weight_name, shard_id in stacked_params_mapping:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
is_fused_expert = True
expert_params_mapping = fused_expert_params_mapping
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
if "visual" in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(ignore_suffixes) and name not in params_dict:
continue
# [TODO] Skip layers that are on other devices (check if sglang has a similar function)
# if is_pp_missing_parameter(name, self):
# continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Track if this is an expert weight to enable early skipping
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
if "visual" in name or "audio_tower" in name:
continue
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
name_mapped = name.replace(weight_name, param_name)
if is_fused_expert:
loaded_weight = loaded_weight.transpose(-1, -2) # no bias
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-2)
load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[0],
"w1",
num_experts,
)
load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[1],
"w3",
num_experts,
)
else:
load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight,
shard_id,
num_experts,
)
else:
# Skip loading extra parameters for GPTQ/modelopt models.
if (
name_mapped.endswith(ignore_suffixes)
and name_mapped not in params_dict
):
continue
param = params_dict[name_mapped]
# We should ask the weight loader to return success or
# not here, since otherwise we may skip experts with
# other available replicas.
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
)
name = name_mapped
break
else:
if is_expert_weight:
# This is an expert weight but not mapped to this rank, skip all remaining processing
continue
if "visual" in name or "audio_tower" in name:
# adapt to VisionAttention
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
name = name.replace(r"model.visual.", r"visual.")
name = name.replace(r"attn.out_proj.", r"attn.proj.")
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(ignore_suffixes) and name not in params_dict:
continue
if name in params_dict.keys():
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
else:
logger.warning(
f"Loaded weight with {name=} not found in params_dict"
)
EntryClass = Qwen3OmniMoeForConditionalGeneration
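For orientation, the sketch below (not part of the diff) mirrors the checkpoint-name remapping that load_weights applies before the stacked/expert mappings take over; the sample names are purely illustrative.
def remap_checkpoint_name(name: str, enable_talker: bool = False):
    """Illustrative remapping only; returns None when the weight is skipped."""
    name = name.replace("model.language_model.", "model.")
    if ("talker" in name or "code2wav" in name) and not enable_talker:
        return None  # thinker-only serving drops talker / code2wav weights
    return name.replace(".self_attn.out_proj", ".self_attn.proj")

assert remap_checkpoint_name("model.language_model.layers.0.self_attn.q_proj.weight") == "model.layers.0.self_attn.q_proj.weight"
assert remap_checkpoint_name("talker.model.layers.0.mlp.gate_proj.weight") is None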
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Inference-only Qwen3-VL model compatible with HuggingFace weights.""" """Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import logging import logging
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union from typing import Callable, Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -27,7 +27,11 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( ...@@ -27,7 +27,11 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionRotaryEmbedding,
) )
from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig from sglang.srt.configs.qwen3_vl import (
Qwen3VLConfig,
Qwen3VLTextConfig,
Qwen3VLVisionConfig,
)
from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
...@@ -38,16 +42,24 @@ from sglang.srt.managers.mm_utils import ( ...@@ -38,16 +42,24 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens, MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine, general_mm_embed_routine,
) )
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs from sglang.srt.managers.schedule_batch import (
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
from sglang.srt.models.qwen3 import Qwen3Model from sglang.srt.models.qwen3 import Qwen3Model
from sglang.srt.utils import add_prefix from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor from sglang.srt.utils.hf_transformers_utils import get_processor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# === Vision Encoder === # # === Vision Encoder === #
...@@ -196,7 +208,7 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -196,7 +208,7 @@ class Qwen3_VisionBlock(nn.Module):
return x return x
class Qwen3_VisionPatchMerger(nn.Module): class Qwen3VLMoeVisionPatchMerger(nn.Module):
def __init__( def __init__(
self, self,
...@@ -246,7 +258,7 @@ class Qwen3_VisionPatchMerger(nn.Module): ...@@ -246,7 +258,7 @@ class Qwen3_VisionPatchMerger(nn.Module):
return out return out
class Qwen3_VisionTransformer(nn.Module): class Qwen3VLMoeVisionModel(nn.Module):
def __init__( def __init__(
self, self,
...@@ -263,10 +275,10 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -263,10 +275,10 @@ class Qwen3_VisionTransformer(nn.Module):
self.spatial_merge_size = vision_config.spatial_merge_size self.spatial_merge_size = vision_config.spatial_merge_size
self.spatial_merge_unit = self.spatial_merge_size**2 self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size self.temporal_patch_size = vision_config.temporal_patch_size
# layer indexes of which layer's output should be deep-stacked
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config) self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size) self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
norm_layer = partial(nn.LayerNorm, eps=norm_eps) norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
...@@ -286,7 +298,7 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -286,7 +298,7 @@ class Qwen3_VisionTransformer(nn.Module):
for layer_idx in range(vision_config.depth) for layer_idx in range(vision_config.depth)
] ]
) )
self.merger = Qwen3_VisionPatchMerger( self.merger = Qwen3VLMoeVisionPatchMerger(
dim=vision_config.out_hidden_size, dim=vision_config.out_hidden_size,
context_dim=self.hidden_size, context_dim=self.hidden_size,
norm_layer=norm_layer, norm_layer=norm_layer,
...@@ -297,7 +309,7 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -297,7 +309,7 @@ class Qwen3_VisionTransformer(nn.Module):
self.deepstack_merger_list = nn.ModuleList( self.deepstack_merger_list = nn.ModuleList(
[ [
Qwen3_VisionPatchMerger( Qwen3VLMoeVisionPatchMerger(
dim=vision_config.out_hidden_size, dim=vision_config.out_hidden_size,
context_dim=self.hidden_size, context_dim=self.hidden_size,
spatial_merge_size=self.spatial_merge_size, spatial_merge_size=self.spatial_merge_size,
...@@ -462,7 +474,6 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -462,7 +474,6 @@ class Qwen3_VisionTransformer(nn.Module):
] ]
) )
# max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
x = x.unsqueeze(1) x = x.unsqueeze(1)
deepstack_feature_lists = [] deepstack_feature_lists = []
...@@ -604,37 +615,43 @@ class Qwen3VLForConditionalGeneration(nn.Module): ...@@ -604,37 +615,43 @@ class Qwen3VLForConditionalGeneration(nn.Module):
config: Qwen3VLConfig, config: Qwen3VLConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
language_model_cls=Qwen3LLMModel,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.visual = Qwen3VLMoeVisionModel(
self.visual = Qwen3_VisionTransformer(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization. # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
quant_config=quant_config, quant_config=quant_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
prefix=add_prefix("visual", prefix), prefix=add_prefix("visual", prefix),
) )
self.model = Qwen3LLMModel( # TODO: make it more elegant
config=config, if language_model_cls is Qwen3LLMModel:
self.config: Qwen3VLConfig = config # for qwen3-vl
else:
self.config = config.text_config # for qwen3-omni
self.model = language_model_cls(
config=self.config,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("model", prefix), prefix=add_prefix("model", prefix),
) )
if config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens self.lm_head = self.model.embed_tokens
else: else:
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, self.config.vocab_size,
config.hidden_size, self.config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
) )
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(self.config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on # like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on
# 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states # 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states
...@@ -642,10 +659,7 @@ class Qwen3VLForConditionalGeneration(nn.Module): ...@@ -642,10 +659,7 @@ class Qwen3VLForConditionalGeneration(nn.Module):
# deepstack # deepstack
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes) self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
self.use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True}
@property
def use_deepstack(self) -> bool:
return hasattr(self, "deepstack_visual_indexes")
def separate_deepstack_embeds(self, embedding): def separate_deepstack_embeds(self, embedding):
assert ( assert (
......
...@@ -14,29 +14,19 @@ ...@@ -14,29 +14,19 @@
# ============================================================================== # ==============================================================================
"""Inference-only Qwen3-VL model compatible with HuggingFace weights.""" """Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import logging import logging
from functools import lru_cache, partial from functools import lru_cache
from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union from typing import Iterable, Optional, Tuple, Union
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature
from transformers.activations import ACT2FN
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionRotaryEmbedding,
)
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeVisionConfig from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
from sglang.srt.distributed import ( from sglang.srt.distributed import (
get_moe_expert_parallel_world_size, get_moe_expert_parallel_world_size,
get_pp_group,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import general_mm_embed_routine from sglang.srt.managers.mm_utils import general_mm_embed_routine
...@@ -44,11 +34,7 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem ...@@ -44,11 +34,7 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3_moe import Qwen3MoeModel from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_vl import ( from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
Qwen3_VisionTransformer,
Qwen3VLForConditionalGeneration,
)
from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor from sglang.srt.utils.hf_transformers_utils import get_processor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -60,28 +46,16 @@ class Qwen3MoeLLMModel(Qwen3MoeModel): ...@@ -60,28 +46,16 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
def __init__( def __init__(
self, self,
*, *,
config: Qwen3VLMoeConfig, config: Qwen3VLMoeTextConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__(config=config, quant_config=quant_config, prefix=prefix) super().__init__(config=config, quant_config=quant_config, prefix=prefix)
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
def get_input_embeddings(self) -> nn.Embedding: def get_input_embeddings(self) -> nn.Embedding:
return self.embed_tokens return self.embed_tokens
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
# in qwen-vl, last dim is the same
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
self.visual.dtype
)
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -120,7 +94,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel): ...@@ -120,7 +94,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
) )
# process deepstack # process deepstack
if input_deepstack_embeds is not None and layer_idx in range(3): if input_deepstack_embeds is not None and layer_idx < 3:
sep = self.hidden_size * layer_idx sep = self.hidden_size * layer_idx
hidden_states.add_( hidden_states.add_(
input_deepstack_embeds[:, sep : sep + self.hidden_size] input_deepstack_embeds[:, sep : sep + self.hidden_size]
...@@ -146,144 +120,56 @@ class Qwen3MoeLLMModel(Qwen3MoeModel): ...@@ -146,144 +120,56 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
return hidden_states, aux_hidden_states return hidden_states, aux_hidden_states
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): def load_fused_expert_weights(
def __init__( name: str,
self, params_dict: dict,
*, loaded_weight: torch.Tensor,
config: Qwen3VLMoeConfig, shard_id: str,
quant_config: Optional[QuantizationConfig] = None, num_experts: int,
prefix: str = "", ):
): param = params_dict[name]
super(Qwen3VLForConditionalGeneration, self).__init__() # weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
self.config = config weight_loader = param.weight_loader
ep_rank = get_tensor_model_parallel_rank()
self.visual = Qwen3_VisionTransformer( ep_size = get_moe_expert_parallel_world_size()
config.vision_config, if ep_size == 1:
norm_eps=getattr(config, "rms_norm_eps", 1e-6), for expert_id in range(num_experts):
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization. curr_expert_weight = loaded_weight[expert_id]
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. weight_loader(
quant_config=quant_config, param,
prefix=add_prefix("visual", prefix), curr_expert_weight,
) name,
shard_id,
self.model = Qwen3MoeLLMModel( expert_id,
config=config,
quant_config=quant_config,
prefix=add_prefix("model", prefix),
)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
) )
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling else:
experts_per_ep = num_experts // ep_size
self.logits_processor = LogitsProcessor(config) start_expert = ep_rank * experts_per_ep
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) end_expert = (
(ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
# deepstack
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
@property
def use_deepstack(self) -> bool:
return hasattr(self, "deepstack_visual_indexes")
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
get_embedding: bool = False,
):
"""Run forward pass for Qwen3-VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,).
(Use input_metadata.mrope_positions to replace it)
"""
if self.is_mrope_enabled:
positions = forward_batch.mrope_positions
if not (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
if self.is_mrope_enabled:
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.model,
multimodal_model=self,
positions=positions,
use_deepstack=self.use_deepstack,
) )
if not get_embedding: for idx, expert_id in enumerate(range(start_expert, end_expert)):
return self.logits_processor( curr_expert_weight = loaded_weight[expert_id]
input_ids, hidden_states, self.lm_head, forward_batch weight_loader(
param,
curr_expert_weight,
name,
shard_id,
idx,
) )
else: return True
return self.pooler(hidden_states, forward_batch)
def load_fused_expert_weights(
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
def __init__(
self, self,
name: str, config: Qwen3VLMoeConfig,
params_dict: dict, quant_config: Optional[QuantizationConfig] = None,
loaded_weight: torch.Tensor, prefix: str = "",
shard_id: str, language_model_cls=Qwen3MoeLLMModel,
num_experts: int,
): ):
param = params_dict[name] super().__init__(config, quant_config, prefix, language_model_cls)
# weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
weight_loader = param.weight_loader
ep_rank = get_tensor_model_parallel_rank()
ep_size = get_moe_expert_parallel_world_size()
if ep_size == 1:
for expert_id in range(num_experts):
curr_expert_weight = loaded_weight[expert_id]
weight_loader(
param,
curr_expert_weight,
name,
shard_id,
expert_id,
)
else:
experts_per_ep = num_experts // ep_size
start_expert = ep_rank * experts_per_ep
end_expert = (
(ep_rank + 1) * experts_per_ep
if ep_rank != ep_size - 1
else num_experts
)
for idx, expert_id in enumerate(range(start_expert, end_expert)):
curr_expert_weight = loaded_weight[expert_id]
weight_loader(
param,
curr_expert_weight,
name,
shard_id,
idx,
)
return True
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
...@@ -329,8 +215,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): ...@@ -329,8 +215,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
self._cached_params_dict = dict(self.named_parameters()) self._cached_params_dict = dict(self.named_parameters())
params_dict = self._cached_params_dict params_dict = self._cached_params_dict
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "language_model" in name: name = name.replace(r"model.language_model.", r"model.")
name = name.replace(r"model.language_model.", r"model.")
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if "experts.gate_up_proj" in name or "experts.down_proj" in name: if "experts.gate_up_proj" in name or "experts.down_proj" in name:
...@@ -384,14 +269,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): ...@@ -384,14 +269,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
loaded_weight = loaded_weight.transpose(-1, -2) # no bias loaded_weight = loaded_weight.transpose(-1, -2) # no bias
if "experts.gate_up_proj" in name: if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-2) loaded_weight = loaded_weight.chunk(2, dim=-2)
self.load_fused_expert_weights( load_fused_expert_weights(
name_mapped, name_mapped,
params_dict, params_dict,
loaded_weight[0], loaded_weight[0],
"w1", "w1",
num_experts, num_experts,
) )
self.load_fused_expert_weights( load_fused_expert_weights(
name_mapped, name_mapped,
params_dict, params_dict,
loaded_weight[1], loaded_weight[1],
...@@ -399,7 +284,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): ...@@ -399,7 +284,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
num_experts, num_experts,
) )
else: else:
self.load_fused_expert_weights( load_fused_expert_weights(
name_mapped, name_mapped,
params_dict, params_dict,
loaded_weight, loaded_weight,
......
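As a quick sanity check on the expert-parallel branch of load_fused_expert_weights above, here is the slicing arithmetic with assumed values (num_experts and ep_size are made up):
num_experts, ep_size = 128, 6
for ep_rank in range(ep_size):
    experts_per_ep = num_experts // ep_size  # 21
    start_expert = ep_rank * experts_per_ep
    end_expert = (ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
    # rank 0 loads experts [0, 21), ..., rank 5 absorbs the remainder: [105, 128)
    print(f"ep_rank={ep_rank} loads experts [{start_expert}, {end_expert})")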
...@@ -155,7 +155,6 @@ class BaseMultimodalProcessor(ABC): ...@@ -155,7 +155,6 @@ class BaseMultimodalProcessor(ABC):
): ):
self.hf_config = hf_config self.hf_config = hf_config
self._processor = _processor self._processor = _processor
self.arch = hf_config.architectures[0]
self.server_args = server_args self.server_args = server_args
self.transport_mode = transport_mode self.transport_mode = transport_mode
...@@ -191,6 +190,7 @@ class BaseMultimodalProcessor(ABC): ...@@ -191,6 +190,7 @@ class BaseMultimodalProcessor(ABC):
"input_features": Modality.AUDIO, "input_features": Modality.AUDIO,
"input_features_mask": Modality.AUDIO, "input_features_mask": Modality.AUDIO,
"audio_attention_mask": Modality.AUDIO, "audio_attention_mask": Modality.AUDIO,
"feature_attention_mask": Modality.AUDIO,
# Video-related attributes # Video-related attributes
"pixel_values_videos": Modality.VIDEO, "pixel_values_videos": Modality.VIDEO,
"second_per_grid_ts": Modality.VIDEO, "second_per_grid_ts": Modality.VIDEO,
...@@ -222,6 +222,7 @@ class BaseMultimodalProcessor(ABC): ...@@ -222,6 +222,7 @@ class BaseMultimodalProcessor(ABC):
if self._processor.__class__.__name__ in { if self._processor.__class__.__name__ in {
"Gemma3nProcessor", "Gemma3nProcessor",
"Qwen2AudioProcessor", "Qwen2AudioProcessor",
"Qwen3OmniMoeProcessor",
}: }:
# Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107 # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
kwargs["audio"] = audios kwargs["audio"] = audios
......
...@@ -12,6 +12,7 @@ from torchvision.transforms import InterpolationMode ...@@ -12,6 +12,7 @@ from torchvision.transforms import InterpolationMode
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
from sglang.srt.models.qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration
from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import ( from sglang.srt.multimodal.processors.base_processor import (
...@@ -209,22 +210,31 @@ async def preprocess_video( ...@@ -209,22 +210,31 @@ async def preprocess_video(
return video return video
# Compatible with Qwen2VL and Qwen2_5VL # Compatible with Qwen-VL & Qwen-Omni Series
class Qwen2_5VLImageProcessor(SGLangBaseProcessor): class QwenVLImageProcessor(SGLangBaseProcessor):
models = [ models = [
Qwen2VLForConditionalGeneration, Qwen2VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration,
Qwen3VLForConditionalGeneration, Qwen3VLForConditionalGeneration,
Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeForConditionalGeneration,
Qwen3OmniMoeForConditionalGeneration,
] ]
def __init__(self, hf_config, server_args, _processor, *args, **kwargs): def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
self.model_type = hf_config.model_type
if hf_config.model_type == "qwen3_omni_moe":
hf_config = hf_config.thinker_config
super().__init__(hf_config, server_args, _processor, *args, **kwargs) super().__init__(hf_config, server_args, _processor, *args, **kwargs)
# The regex that matches expanded image tokens.
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
self.vision_start_token_id = hf_config.vision_start_token_id self.vision_start_token_id = hf_config.vision_start_token_id
self.vision_end_token_id = hf_config.vision_end_token_id self.vision_end_token_id = hf_config.vision_end_token_id
self.audio_start_token_id = getattr(hf_config, "audio_start_token_id", None)
self.audio_token_id = getattr(hf_config, "audio_token_id", None)
self.NUM_TOKEN_PER_FRAME = 770 self.NUM_TOKEN_PER_FRAME = 770
self.IMAGE_FACTOR = 28 self.IMAGE_FACTOR = 28
self.MIN_PIXELS = 4 * 28 * 28 self.MIN_PIXELS = 4 * 28 * 28
...@@ -233,10 +243,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): ...@@ -233,10 +243,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self.mm_tokens = MultimodalSpecialTokens( self.mm_tokens = MultimodalSpecialTokens(
image_token="<|vision_start|><|image_pad|><|vision_end|>", image_token="<|vision_start|><|image_pad|><|vision_end|>",
image_token_id=hf_config.image_token_id, image_token_id=hf_config.image_token_id,
# The regex that matches expanded image tokens.
image_token_regex=re.compile( image_token_regex=re.compile(
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>" r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
), ),
video_token_id=hf_config.video_token_id, video_token_id=hf_config.video_token_id,
audio_token_id=self.audio_token_id,
).build(_processor) ).build(_processor)
async def process_mm_data_async( async def process_mm_data_async(
...@@ -247,11 +259,11 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): ...@@ -247,11 +259,11 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
*args, *args,
**kwargs, **kwargs,
): ):
base_output = self.load_mm_data( base_output = self.load_mm_data(
prompt=input_text, prompt=input_text,
image_data=image_data, image_data=image_data,
video_data=request_obj.video_data, video_data=request_obj.video_data,
audio_data=request_obj.audio_data,
multimodal_tokens=self.mm_tokens, multimodal_tokens=self.mm_tokens,
) )
...@@ -269,20 +281,41 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): ...@@ -269,20 +281,41 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
base_output, self.mm_tokens base_output, self.mm_tokens
) )
audio_feature_lengths = None
if self.model_type == "qwen3_omni_moe":
audio_item = next((mm for mm in mm_items if mm.is_audio()), None)
if audio_item:
audio_feature_lengths = torch.sum(
audio_item.feature_attention_mask, dim=1
)
second_per_grid_ts = getattr(ret, "second_per_grid_ts", None) or getattr(
ret, "video_second_per_grid", None
)
input_ids = input_ids.flatten() input_ids = input_ids.flatten()
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size, spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
image_token_id=self.mm_tokens.image_token_id, image_token_id=self.mm_tokens.image_token_id,
video_token_id=self.mm_tokens.video_token_id, video_token_id=self.mm_tokens.video_token_id,
vision_start_token_id=self.vision_start_token_id, vision_start_token_id=self.vision_start_token_id,
model_type=self.hf_config.model_type, model_type=self.model_type,
tokens_per_second=getattr( tokens_per_second=getattr(
self.hf_config.vision_config, "tokens_per_second", None self.hf_config.vision_config, "tokens_per_second", None
), ),
input_ids=input_ids.unsqueeze(0), input_ids=input_ids.unsqueeze(0),
image_grid_thw=getattr(ret, "image_grid_thw", None), image_grid_thw=getattr(ret, "image_grid_thw", None),
video_grid_thw=getattr(ret, "video_grid_thw", None), video_grid_thw=getattr(ret, "video_grid_thw", None),
second_per_grid_ts=getattr(ret, "second_per_grid_ts", None), second_per_grid_ts=second_per_grid_ts,
use_audio_in_video=False,
audio_seqlens=audio_feature_lengths,
audio_token_id=getattr(self.hf_config, "audio_token_id", None),
audio_start_token_id=self.audio_start_token_id,
position_id_per_seconds=getattr(
self.hf_config, "position_id_per_seconds", None
),
) )
mrope_positions = mrope_positions.squeeze(1) mrope_positions = mrope_positions.squeeze(1)
...@@ -293,6 +326,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): ...@@ -293,6 +326,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
"im_end_id": self.IM_END_TOKEN_ID, "im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.mm_tokens.image_token_id, "im_token_id": self.mm_tokens.image_token_id,
"video_token_id": self.mm_tokens.video_token_id, "video_token_id": self.mm_tokens.video_token_id,
"audio_token_id": self.mm_tokens.audio_token_id,
"mrope_positions": mrope_positions, "mrope_positions": mrope_positions,
"mrope_position_delta": mrope_position_delta, "mrope_position_delta": mrope_position_delta,
} }
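A tiny, self-contained illustration (synthetic mask values) of how the processor above derives audio_feature_lengths for the qwen3_omni_moe path:
import torch

# Each row is one audio clip; 1 marks a valid mel frame, 0 is padding.
feature_attention_mask = torch.tensor([
    [1, 1, 1, 1, 0, 0],  # clip 0: 4 valid frames
    [1, 1, 1, 1, 1, 1],  # clip 1: 6 valid frames
])
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
assert audio_feature_lengths.tolist() == [4, 6]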
...@@ -355,9 +355,10 @@ class TestPhi4MMServer(ImageOpenAITestMixin, AudioOpenAITestMixin): ...@@ -355,9 +355,10 @@ class TestPhi4MMServer(ImageOpenAITestMixin, AudioOpenAITestMixin):
if __name__ == "__main__": if __name__ == "__main__":
del ( del (
TestOpenAIOmniServerBase, TestOpenAIMLLMServerBase,
ImageOpenAITestMixin, ImageOpenAITestMixin,
VideoOpenAITestMixin, VideoOpenAITestMixin,
AudioOpenAITestMixin, AudioOpenAITestMixin,
OmniOpenAITestMixin,
) )
unittest.main() unittest.main()
...@@ -241,11 +241,35 @@ class TestGLM41VServer(ImageOpenAITestMixin, VideoOpenAITestMixin): ...@@ -241,11 +241,35 @@ class TestGLM41VServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
cls.base_url += "/v1" cls.base_url += "/v1"
class TestQwen3OmniServer(OmniOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[ # workaround to fit into H100
"--trust-remote-code",
"--mem-fraction-static",
"0.90",
"--disable-cuda-graph",
"--disable-fast-image-processor",
"--grammar-backend",
"none",
],
)
cls.base_url += "/v1"
if __name__ == "__main__": if __name__ == "__main__":
del ( del (
TestOpenAIOmniServerBase, TestOpenAIMLLMServerBase,
ImageOpenAITestMixin, ImageOpenAITestMixin,
VideoOpenAITestMixin, VideoOpenAITestMixin,
AudioOpenAITestMixin, AudioOpenAITestMixin,
OmniOpenAITestMixin,
) )
unittest.main() unittest.main()
import base64 import base64
import io import io
import os import os
from concurrent.futures import ThreadPoolExecutor
import numpy as np import numpy as np
import openai import openai
...@@ -22,7 +23,7 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test ...@@ -22,7 +23,7 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3" AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
class TestOpenAIOmniServerBase(CustomTestCase): class TestOpenAIMLLMServerBase(CustomTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "" cls.model = ""
...@@ -58,7 +59,20 @@ class TestOpenAIOmniServerBase(CustomTestCase): ...@@ -58,7 +59,20 @@ class TestOpenAIOmniServerBase(CustomTestCase):
return file_path return file_path
class AudioOpenAITestMixin(TestOpenAIOmniServerBase): class AudioOpenAITestMixin(TestOpenAIMLLMServerBase):
def verify_speech_recognition_response(self, text):
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in text.lower()
), f"audio_response: |{text}| should contain |{check_word}|"
def prepare_audio_messages(self, prompt, audio_file_name): def prepare_audio_messages(self, prompt, audio_file_name):
messages = [ messages = [
{ {
...@@ -116,17 +130,7 @@ class AudioOpenAITestMixin(TestOpenAIOmniServerBase): ...@@ -116,17 +130,7 @@ class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
"Listen to this audio and write down the audio transcription in English.", "Listen to this audio and write down the audio transcription in English.",
category="speech", category="speech",
) )
check_list = [ self.verify_speech_recognition_response(audio_response)
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in audio_response
), f"audio_response: |{audio_response}| should contain |{check_word}|"
def test_audio_ambient_completion(self): def test_audio_ambient_completion(self):
# bird song # bird song
...@@ -138,26 +142,39 @@ class AudioOpenAITestMixin(TestOpenAIOmniServerBase): ...@@ -138,26 +142,39 @@ class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
assert "bird" in audio_response assert "bird" in audio_response
class ImageOpenAITestMixin(TestOpenAIOmniServerBase): class ImageOpenAITestMixin(TestOpenAIMLLMServerBase):
def test_single_image_chat_completion(self): def run_decode_with_image(self, image_id):
client = openai.Client(api_key=self.api_key, base_url=self.base_url) client = openai.Client(api_key=self.api_key, base_url=self.base_url)
content = []
if image_id == 0:
content.append(
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
}
)
elif image_id == 1:
content.append(
{
"type": "image_url",
"image_url": {"url": IMAGE_SGL_LOGO_URL},
}
)
else:
pass
content.append(
{
"type": "text",
"text": "Describe this image in a sentence.",
}
)
response = client.chat.completions.create( response = client.chat.completions.create(
model="default", model="default",
messages=[ messages=[
{ {"role": "user", "content": content},
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "text",
"text": "Describe this image in a sentence.",
},
],
},
], ],
temperature=0, temperature=0,
**(self.get_vision_request_kwargs()), **(self.get_vision_request_kwargs()),
...@@ -166,6 +183,17 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase): ...@@ -166,6 +183,17 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
assert response.choices[0].message.role == "assistant" assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content text = response.choices[0].message.content
assert isinstance(text, str) assert isinstance(text, str)
def test_mixed_batch(self):
image_ids = [0, 1, 2] * 4
with ThreadPoolExecutor(4) as executor:
list(executor.map(self.run_decode_with_image, image_ids))
def verify_single_image_response(self, response):
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
# `driver` is for gemma-3-it # `driver` is for gemma-3-it
assert ( assert (
"man" in text or "person" or "driver" in text "man" in text or "person" or "driver" in text
...@@ -179,19 +207,44 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase): ...@@ -179,19 +207,44 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
), f"text: {text}, should contain cab, taxi, SUV, vehicle or car" ), f"text: {text}, should contain cab, taxi, SUV, vehicle or car"
# MiniCPMO fails to recognize `iron`, but `hanging` # MiniCPMO fails to recognize `iron`, but `hanging`
assert ( assert (
"iron" in text "iron" in text or "hang" in text or "cloth" in text or "holding" in text
or "hang" in text ), f"text: {text}, should contain iron, hang, cloth or holding"
or "cloth" in text
or "coat" in text
or "holding" in text
or "outfit" in text
), f"text: {text}, should contain iron, hang, cloth, coat or holding or outfit"
assert response.id assert response.id
assert response.created assert response.created
assert response.usage.prompt_tokens > 0 assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0 assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0 assert response.usage.total_tokens > 0
def test_single_image_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "text",
"text": "Describe this image in a sentence.",
},
],
},
],
temperature=0,
**(self.get_vision_request_kwargs()),
)
print("-" * 30)
print(f"Single image response:\n{response.choices[0].message.content}")
print("-" * 30)
self.verify_single_image_response(response)
def test_multi_turn_chat_completion(self): def test_multi_turn_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url) client = openai.Client(api_key=self.api_key, base_url=self.base_url)
...@@ -264,8 +317,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase): ...@@ -264,8 +317,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
}, },
{ {
"type": "text", "type": "text",
"text": "I have two very different images. They are not related at all. " "text": "I have two very different images. Please describe them.",
"Please describe the first image in one sentence, and then describe the second image in another sentence.",
}, },
], ],
}, },
...@@ -296,64 +348,6 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase): ...@@ -296,64 +348,6 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
assert response.usage.completion_tokens > 0 assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0 assert response.usage.total_tokens > 0
def _test_mixed_image_audio_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "audio_url",
"audio_url": {"url": AUDIO_TRUMP_SPEECH_URL},
},
{
"type": "text",
"text": "Please describe the image in one sentence, and then write down the audio transcription in English.",
},
],
},
],
temperature=0,
**(self.get_vision_request_kwargs()),
)
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
print("-" * 30)
print(f"Mixed image & audio response:\n{text}")
print("-" * 30)
assert (
"man" in text
or "cab" in text
or "SUV" in text
or "taxi" in text
or "car" in text
), f"text: {text}, should contain man, cab, SUV, taxi or car"
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in text
), f"text: |{text}| should contain |{check_word}|"
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def prepare_video_images_messages(self, video_path): def prepare_video_images_messages(self, video_path):
# the memory consumed by the Vision Attention varies a lot, e.g. blocked qkv vs full-sequence sdpa # the memory consumed by the Vision Attention varies a lot, e.g. blocked qkv vs full-sequence sdpa
# the size of the video embeds differs from the `modality` argument when preprocessed # the size of the video embeds differs from the `modality` argument when preprocessed
...@@ -461,7 +455,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase): ...@@ -461,7 +455,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
self.assertGreater(len(video_response), 0) self.assertGreater(len(video_response), 0)
class VideoOpenAITestMixin(TestOpenAIOmniServerBase): class VideoOpenAITestMixin(TestOpenAIMLLMServerBase):
def prepare_video_messages(self, video_path): def prepare_video_messages(self, video_path):
messages = [ messages = [
{ {
...@@ -526,3 +520,45 @@ class VideoOpenAITestMixin(TestOpenAIOmniServerBase): ...@@ -526,3 +520,45 @@ class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
), f"video_response: {video_response}, should contain 'black' or 'dark'" ), f"video_response: {video_response}, should contain 'black' or 'dark'"
self.assertIsNotNone(video_response) self.assertIsNotNone(video_response)
self.assertGreater(len(video_response), 0) self.assertGreater(len(video_response), 0)
class OmniOpenAITestMixin(
ImageOpenAITestMixin, VideoOpenAITestMixin, AudioOpenAITestMixin
):
def test_mixed_modality_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "audio_url",
"audio_url": {"url": AUDIO_TRUMP_SPEECH_URL},
},
{
"type": "text",
"text": "I have an image and audio, which are not related at all. Please: 1. Describe the image in a sentence, 2. Repeat the exact words from the audio I provided. Be exact",
},
],
},
]
response = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=128,
stream=False,
)
text = response.choices[0].message.content
print("-" * 30)
print(f"Mixed modality response:\n{text}")
print("-" * 30)
self.verify_single_image_response(response=response)
self.verify_speech_recognition_response(text=text)