Unverified Commit 86b04d25 authored by Mick's avatar Mick Committed by GitHub
Browse files

model: qwen3-omni (thinker-only) (#10911)


Co-authored-by: default avatarXinyuan Tong <xinyuantong.cs@gmail.com>
parent 85ebeecf
......@@ -853,6 +853,7 @@ multimodal_model_archs = [
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"Qwen3OmniMoeForConditionalGeneration",
"KimiVLForConditionalGeneration",
"InternVLChatModel",
"InternS1ForConditionalGeneration",
......
from transformers import PretrainedConfig
from transformers.configuration_utils import layer_type_validation
from transformers.modeling_rope_utils import rope_config_validation
from sglang.utils import logger
class Qwen3OmniMoeAudioEncoderConfig(PretrainedConfig):
model_type = "qwen3_omni_moe_audio_encoder"
def __init__(
self,
num_mel_bins=128,
encoder_layers=32,
encoder_attention_heads=20,
encoder_ffn_dim=5120,
d_model=1280,
dropout=0,
attention_dropout=0,
activation_function="gelu",
activation_dropout=0,
scale_embedding=False,
initializer_range=0.02,
max_source_positions=1500,
n_window=100,
output_dim=3584,
n_window_infer=400,
conv_chunksize=500,
downsample_hidden_size=480,
**kwargs,
):
super().__init__(**kwargs)
self.num_mel_bins = num_mel_bins
self.d_model = d_model
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.encoder_ffn_dim = encoder_ffn_dim
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_function = activation_function
self.activation_dropout = activation_dropout
self.num_hidden_layers = encoder_layers
self.initializer_range = initializer_range
self.scale_embedding = (
scale_embedding # scale factor will be sqrt(d_model) if True
)
self.max_source_positions = max_source_positions
self.n_window = n_window
self.output_dim = output_dim
self.n_window_infer = n_window_infer
self.conv_chunksize = conv_chunksize
self.downsample_hidden_size = downsample_hidden_size
class Qwen3OmniMoeVisionEncoderConfig(PretrainedConfig):
model_type = "qwen3_omni_moe_vision_encoder"
base_config_key = "vision_config"
def __init__(
self,
depth=27,
hidden_size=1152,
hidden_act="gelu_pytorch_tanh",
intermediate_size=4304,
num_heads=16,
in_channels=3,
patch_size=16,
spatial_merge_size=2,
temporal_patch_size=2,
out_hidden_size=3584,
num_position_embeddings=2304,
deepstack_visual_indexes=[8, 16, 24],
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.out_hidden_size = out_hidden_size
self.num_position_embeddings = num_position_embeddings
self.initializer_range = initializer_range
self.deepstack_visual_indexes = deepstack_visual_indexes
class Qwen3OmniMoeTextConfig(PretrainedConfig):
model_type = "qwen3_omni_moe_text"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen3OmniMoeText`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.experts.*.gate_proj": "colwise",
"layers.*.mlp.experts.*.up_proj": "colwise",
"layers.*.mlp.experts.*.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=3584,
hidden_size=2048,
intermediate_size=18944,
num_hidden_layers=28,
num_attention_heads=28,
num_key_value_heads=4,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=1000000.0,
rope_scaling=None,
attention_bias=False,
sliding_window=None,
attention_dropout=0,
decoder_sparse_step=1,
moe_intermediate_size=768,
num_experts_per_tok=8,
num_experts=128,
norm_topk_prob=True,
output_router_logits=False,
router_aux_loss_coef=0.001,
mlp_only_layers=None,
**kwargs,
):
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
# MoE arguments
self.decoder_sparse_step = decoder_sparse_step
self.moe_intermediate_size = moe_intermediate_size
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.norm_topk_prob = norm_topk_prob
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
class Qwen3OmniMoeThinkerConfig(PretrainedConfig):
model_type = "qwen3_omni_moe_thinker"
attribute_map = {
"image_token_id": "image_token_index",
"video_token_id": "video_token_index",
"audio_token_id": "audio_token_index",
}
sub_configs = {
"audio_config": Qwen3OmniMoeAudioEncoderConfig,
"vision_config": Qwen3OmniMoeVisionEncoderConfig,
"text_config": Qwen3OmniMoeTextConfig,
}
def __init__(
self,
audio_config=None,
vision_config=None,
text_config=None,
audio_token_id=151646,
image_token_id=151655,
video_token_id=151656,
position_id_per_seconds=25,
audio_start_token_id=151647,
user_token_id=872,
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.user_token_id = user_token_id
self.position_id_per_seconds = position_id_per_seconds
self.audio_start_token_id = audio_start_token_id
self.initializer_range = initializer_range
if isinstance(vision_config, dict):
vision_config = Qwen3OmniMoeVisionEncoderConfig(**vision_config)
elif vision_config is None:
vision_config = Qwen3OmniMoeVisionEncoderConfig()
self.vision_config = vision_config
if isinstance(audio_config, dict):
audio_config = Qwen3OmniMoeAudioEncoderConfig(**audio_config)
elif audio_config is None:
audio_config = Qwen3OmniMoeAudioEncoderConfig()
self.audio_config = audio_config
if isinstance(text_config, dict):
text_config = Qwen3OmniMoeTextConfig(**text_config)
elif text_config is None:
text_config = Qwen3OmniMoeTextConfig()
self.text_config = text_config
self.audio_token_id = audio_token_id
self.image_token_id = image_token_id
self.video_token_id = video_token_id
class Qwen3OmniMoeTalkerCodePredictorConfig(PretrainedConfig):
model_type = "qwen3_omni_moe_talker_code_predictor"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen3OmniMoeTalkerCodePredictor`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=2048,
hidden_size=1024,
intermediate_size=3072,
num_hidden_layers=5,
num_attention_heads=16,
num_key_value_heads=8,
head_dim=128,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=0.000001,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000,
rope_scaling=None,
attention_bias=False,
sliding_window=None,
layer_types=None,
attention_dropout=0,
num_code_groups=32,
**kwargs,
):
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = [
(
"sliding_attention"
if self.sliding_window is not None and i >= self.max_window_layers
else "full_attention"
)
for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types, self.num_hidden_layers)
self.num_code_groups = num_code_groups
class Qwen3OmniMoeTalkerTextConfig(PretrainedConfig):
model_type = "qwen3_omni_moe_talker_text"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen3OmniMoeTalkerText`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.experts.*.gate_proj": "colwise",
"layers.*.mlp.experts.*.up_proj": "colwise",
"layers.*.mlp.experts.*.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=3072,
hidden_size=1024,
intermediate_size=2048,
num_hidden_layers=20,
num_attention_heads=16,
num_key_value_heads=2,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=0.000001,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000,
rope_scaling=None,
attention_bias=False,
sliding_window=None,
attention_dropout=0,
decoder_sparse_step=1,
moe_intermediate_size=384,
num_experts_per_tok=8,
num_experts=128,
norm_topk_prob=False,
output_router_logits=False,
router_aux_loss_coef=0.001,
mlp_only_layers=None,
**kwargs,
):
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
# MoE arguments
self.decoder_sparse_step = decoder_sparse_step
self.moe_intermediate_size = moe_intermediate_size
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.norm_topk_prob = norm_topk_prob
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
class Qwen3OmniMoeTalkerConfig(PretrainedConfig):
sub_configs = {
"code_predictor_config": Qwen3OmniMoeTalkerCodePredictorConfig,
"text_config": Qwen3OmniMoeTalkerTextConfig,
}
def __init__(
self,
code_predictor_config=None,
text_config=None,
num_code_groups=32,
thinker_hidden_size=2048,
codec_eos_token_id=4198,
accept_hidden_layer=18,
codec_nothink_id=4203,
codec_think_bos_id=4204,
codec_think_eos_id=4205,
codec_pad_id=4196,
codec_bos_id=4197,
audio_token_id=151646,
image_token_id=151655,
video_token_id=151656,
vision_start_token_id=151652,
position_id_per_seconds=25,
audio_start_token_id=151669,
speaker_id=None,
**kwargs,
):
super().__init__(**kwargs)
if code_predictor_config is None:
code_predictor_config = {}
self.code_predictor_config = Qwen3OmniMoeTalkerCodePredictorConfig()
logger.info(
"code_predictor_config is None. Initializing code_predictor_config model with default values"
)
elif isinstance(code_predictor_config, Qwen3OmniMoeTalkerCodePredictorConfig):
self.code_predictor_config = code_predictor_config
else:
self.code_predictor_config = Qwen3OmniMoeTalkerCodePredictorConfig(
**code_predictor_config
)
if text_config is None:
text_config = {}
self.text_config = Qwen3OmniMoeTalkerTextConfig()
logger.info(
"talker text_config is None. Initializing talker text model with default values"
)
elif isinstance(text_config, Qwen3OmniMoeTalkerTextConfig):
self.text_config = text_config
else:
self.text_config = Qwen3OmniMoeTalkerTextConfig(**text_config)
self.num_code_groups = num_code_groups
self.thinker_hidden_size = thinker_hidden_size
self.codec_eos_token_id = codec_eos_token_id
self.accept_hidden_layer = accept_hidden_layer
self.codec_nothink_id = codec_nothink_id
self.codec_think_bos_id = codec_think_bos_id
self.codec_think_eos_id = codec_think_eos_id
self.codec_pad_id = codec_pad_id
self.codec_bos_id = codec_bos_id
self.audio_token_id = audio_token_id
self.image_token_id = image_token_id
self.video_token_id = video_token_id
self.position_id_per_seconds = position_id_per_seconds
self.audio_start_token_id = audio_start_token_id
self.vision_start_token_id = vision_start_token_id
self.speaker_id = speaker_id
class Qwen3OmniMoeCode2WavConfig(PretrainedConfig):
def __init__(
self,
codebook_size=2048,
hidden_size=1024,
max_position_embeddings=8000,
rope_theta=10000,
num_attention_heads=16,
num_key_value_heads=16,
attention_bias=False,
sliding_window=72,
intermediate_size=3072,
hidden_act="silu",
layer_scale_initial_scale=0.01,
rms_norm_eps=1e-5,
num_hidden_layers=8,
num_quantizers=16,
upsample_rates=(8, 5, 4, 3),
upsampling_ratios=(2, 2),
decoder_dim=1536,
attention_dropout=0.0,
**kwargs,
):
super().__init__(**kwargs)
self.codebook_size = codebook_size
self.hidden_size = hidden_size
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.attention_bias = attention_bias
self.sliding_window = sliding_window
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.layer_scale_initial_scale = layer_scale_initial_scale
self.rms_norm_eps = rms_norm_eps
self.num_hidden_layers = num_hidden_layers
self.num_quantizers = num_quantizers
self.upsample_rates = upsample_rates
self.upsampling_ratios = upsampling_ratios
self.decoder_dim = decoder_dim
self.attention_dropout = attention_dropout
@property
def layer_types(self):
"""
All layer in code2wav should be sliding attention
"""
return ["sliding_attention"] * self.num_hidden_layers
class Qwen3OmniMoeConfig(PretrainedConfig):
model_type = "qwen3_omni_moe"
sub_configs = {
"thinker_config": Qwen3OmniMoeThinkerConfig,
"talker_config": Qwen3OmniMoeTalkerConfig,
"code2wav_config": Qwen3OmniMoeCode2WavConfig,
}
def __init__(
self,
thinker_config=None,
talker_config=None,
code2wav_config=None,
enable_audio_output=True,
im_start_token_id=151644,
im_end_token_id=151645,
tts_pad_token_id=151671,
tts_bos_token_id=151672,
tts_eos_token_id=151673,
system_token_id=8948,
user_token_id=872,
assistant_token_id=77091,
**kwargs,
):
super().__init__(**kwargs)
if thinker_config is None:
thinker_config = {}
logger.info(
"thinker_config is None. Initializing thinker model with default values"
)
if talker_config is None:
talker_config = {}
logger.info(
"talker_config is None. Initializing talker model with default values"
)
if code2wav_config is None:
code2wav_config = {}
logger.info(
"code2wav_config is None. Initializing code2wav model with default values"
)
self.thinker_config = Qwen3OmniMoeThinkerConfig(**thinker_config)
self.talker_config = Qwen3OmniMoeTalkerConfig(**talker_config)
self.code2wav_config = Qwen3OmniMoeCode2WavConfig(**code2wav_config)
self.enable_audio_output = enable_audio_output
self.im_start_token_id = im_start_token_id
self.im_end_token_id = im_end_token_id
self.tts_pad_token_id = tts_pad_token_id
self.tts_bos_token_id = tts_bos_token_id
self.tts_eos_token_id = tts_eos_token_id
self.system_token_id = system_token_id
self.user_token_id = user_token_id
self.assistant_token_id = assistant_token_id
def get_text_config(self, decoder=False) -> "PretrainedConfig":
"""
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
itself. On specific composite models, it is under a set of valid names.
Args:
decoder (`Optional[bool]`, *optional*, defaults to `False`):
If set to `True`, then only search for decoder config names.
"""
# Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
# except for Qwen yet. This has to be generalized if more deeply nested configs are
# added. NOTE: currently method used only by vLLM
return self.thinker_config.get_text_config()
from typing import Optional, Union
from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
......@@ -576,11 +574,3 @@ class Qwen3VLMoeConfig(PretrainedConfig):
self.vision_start_token_id = vision_start_token_id
self.vision_end_token_id = vision_end_token_id
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
__all__ = [
"Qwen3VLMoeConfig",
"Qwen3VLMoeVisionConfig",
"Qwen3VLConfig",
"Qwen3VLVisionConfig",
]
......@@ -1156,6 +1156,20 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
if model_type == "qwen3_omni_moe":
# For qwen3-omni
return MRotaryEmbedding.get_rope_index_qwen3_omni(
spatial_merge_size,
image_token_id,
video_token_id,
vision_start_token_id,
tokens_per_second,
input_ids,
image_grid_thw,
video_grid_thw,
second_per_grid_ts,
**kwargs,
)
if (
model_type.startswith("qwen3_vl") or model_type.startswith("qwen3_vl_moe")
) and video_grid_thw is not None:
......@@ -1163,6 +1177,7 @@ class MRotaryEmbedding(RotaryEmbedding):
video_grid_thw, video_grid_thw[:, 0], dim=0
)
video_grid_thw[:, 0] = 1
mrope_position_deltas = []
if input_ids is not None and (
image_grid_thw is not None or video_grid_thw is not None
......@@ -1248,7 +1263,11 @@ class MRotaryEmbedding(RotaryEmbedding):
time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten()
elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
elif model_type in (
"qwen2_vl",
"qwen3_vl",
"qwen3_vl_moe",
):
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
......@@ -1256,7 +1275,7 @@ class MRotaryEmbedding(RotaryEmbedding):
.flatten()
)
else:
raise RuntimeError("Unimplemented")
raise RuntimeError(f"Unimplemented model type: {model_type}")
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
......@@ -1306,6 +1325,304 @@ class MRotaryEmbedding(RotaryEmbedding):
mrope_position_deltas = max_position_ids + 1 - s
return position_ids, mrope_position_deltas
@staticmethod
def get_rope_index_qwen3_omni(
spatial_merge_size: int,
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
tokens_per_second: Optional[int] = None,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# For qwen3-omni
audio_token_id = kwargs["audio_token_id"]
audio_start_token_id = kwargs["audio_start_token_id"]
position_id_per_seconds = kwargs["position_id_per_seconds"]
use_audio_in_video = kwargs.get("use_audio_in_video", False)
audio_seqlens = kwargs.get("audio_seqlens", None)
second_per_grids = second_per_grid_ts
mrope_position_deltas = []
if input_ids is not None and (
image_grid_thw is not None or video_grid_thw is not None
):
total_input_ids = input_ids
position_ids = torch.zeros(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=torch.float,
device=input_ids.device,
)
image_idx, video_idx, audio_idx = 0, 0, 0
for i, current_input_ids in enumerate(total_input_ids):
image_nums, video_nums, audio_nums = 0, 0, 0
vision_start_indices = torch.argwhere(
current_input_ids == vision_start_token_id
).squeeze(1)
if vision_start_indices.numel() > 0:
vision_tokens = current_input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (
(vision_tokens == audio_start_token_id).sum()
if use_audio_in_video
else (vision_tokens == video_token_id).sum()
)
audio_nums = torch.sum(current_input_ids == audio_start_token_id)
input_tokens = current_input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos, remain_audios = (
image_nums,
video_nums,
audio_nums,
)
multimodal_nums = (
image_nums + audio_nums
if use_audio_in_video
else image_nums + video_nums + audio_nums
)
for _ in range(multimodal_nums):
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
ed_vision_start = (
input_tokens.index(vision_start_token_id, st)
if (
(
image_token_id in input_tokens
or video_token_id in input_tokens
)
and (remain_videos > 0 or remain_images > 0)
)
else len(input_tokens) + 1
)
ed_audio_start = (
input_tokens.index(audio_start_token_id, st)
if (audio_token_id in input_tokens and remain_audios > 0)
else len(input_tokens) + 1
)
min_ed = min(ed_vision_start, ed_audio_start)
text_len = min_ed - st
if text_len != 0:
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
st_idx += text_len
# Audio in Video
if (
min_ed == ed_vision_start
and ed_vision_start + 1 == ed_audio_start
):
bos_len, eos_len = 2, 2
else:
bos_len, eos_len = 1, 1
llm_pos_ids_list.append(
torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx
)
st_idx += bos_len
# Audio Only
if min_ed == ed_audio_start:
audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
audio_seqlens[audio_idx]
)
llm_pos_ids = (
torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
)
llm_pos_ids_list.append(llm_pos_ids)
st += int(text_len + bos_len + audio_len + eos_len)
audio_idx += 1
remain_audios -= 1
# Image Only
elif (
min_ed == ed_vision_start
and current_input_ids[ed_vision_start + 1] == image_token_id
):
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = (
torch.arange(grid_t) * 1 * position_id_per_seconds
).float()
llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
st_idx,
image_idx,
spatial_merge_size,
t_index,
grid_hs,
grid_ws,
input_ids.device,
)
image_len = image_grid_thw[image_idx].prod() // (
spatial_merge_size**2
)
llm_pos_ids_list.append(llm_pos_ids)
st += int(text_len + bos_len + image_len + eos_len)
image_idx += 1
remain_images -= 1
# Video Only
elif (
min_ed == ed_vision_start
and current_input_ids[ed_vision_start + 1] == video_token_id
):
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* second_per_grids[video_idx].cpu().float()
* position_id_per_seconds
).float()
llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
st_idx,
video_idx,
spatial_merge_size,
t_index,
grid_hs,
grid_ws,
input_ids.device,
)
video_len = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2
)
llm_pos_ids_list.append(llm_pos_ids)
st += int(text_len + bos_len + video_len + eos_len)
video_idx += 1
remain_videos -= 1
# Audio in Video
elif (
min_ed == ed_vision_start
and ed_vision_start + 1 == ed_audio_start
):
audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
audio_seqlens[audio_idx]
)
audio_llm_pos_ids = (
torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
)
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* second_per_grids[video_idx].cpu().float()
* position_id_per_seconds
).float()
video_llm_pos_ids = (
MRotaryEmbedding._get_llm_pos_ids_for_vision(
st_idx,
video_idx,
spatial_merge_size,
t_index,
grid_hs,
grid_ws,
input_ids.device,
)
)
video_data_index, audio_data_index = 0, 0
while (
video_data_index < video_llm_pos_ids.shape[-1]
and audio_data_index < audio_llm_pos_ids.shape[-1]
):
if (
video_llm_pos_ids[0][video_data_index]
<= audio_llm_pos_ids[0][audio_data_index]
):
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_data_index + 1
]
)
video_data_index += 1
else:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_data_index + 1
]
)
audio_data_index += 1
if video_data_index < video_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_llm_pos_ids.shape[-1]
]
)
if audio_data_index < audio_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_llm_pos_ids.shape[-1]
]
)
video_len = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2
)
st += int(text_len + bos_len + audio_len + video_len + eos_len)
audio_idx += 1
video_idx += 1
remain_videos -= 1
remain_audios -= 1
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
llm_pos_ids_list.append(
torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx
)
if st < len(input_tokens):
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
llm_positions = torch.cat(
[item.float() for item in llm_pos_ids_list], dim=1
).reshape(3, -1)
position_ids[..., i, :] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(
llm_positions.max() + 1 - len(current_input_ids)
)
mrope_position_deltas = torch.tensor(
mrope_position_deltas, device=input_ids.device
).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
s = input_ids.shape[1]
position_ids = torch.arange(s)
position_ids = (
position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
-1, keepdim=True
)[0]
mrope_position_deltas = max_position_ids + 1 - s
return position_ids, mrope_position_deltas
# Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
@staticmethod
def get_rope_index_glm4v(
......@@ -1504,6 +1821,44 @@ class MRotaryEmbedding(RotaryEmbedding):
return position_ids, mrope_position_deltas
# For qwen3-omni
@staticmethod
def _get_feat_extract_output_lengths(input_lengths):
"""
Computes the output length of the convolutional layers and the output length of the audio encoder
"""
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return output_lengths
# For qwen3-omni
@staticmethod
def _get_llm_pos_ids_for_vision(
st_idx, vision_idx, spatial_merge_size, t_index, grid_hs, grid_ws, device
):
grid_h = grid_hs[vision_idx] // spatial_merge_size
grid_w = grid_ws[vision_idx] // spatial_merge_size
h_index = (
torch.arange(grid_h, device=device)
.view(1, -1, 1)
.expand(len(t_index), -1, grid_w)
.flatten()
)
w_index = (
torch.arange(grid_w, device=device)
.view(1, 1, -1)
.expand(len(t_index), grid_h, -1)
.flatten()
)
t_index = t_index.view(-1, 1).expand(-1, grid_h * grid_w).flatten()
llm_pos_ids = torch.stack([t_index, h_index, w_index], dim=0) + st_idx
return llm_pos_ids
class DualChunkRotaryEmbedding(CustomOp):
"""Rotary positional embedding for Dual Chunk Attention."""
......
......@@ -280,7 +280,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
input_ids_tensor[input_ids_tensor == token_id] = pad_value
ret_input_ids = input_ids_tensor.tolist()
return ret_input_ids
......@@ -507,7 +506,7 @@ def embed_mm_inputs(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
use_deepstack: bool = False,
use_deepstack: Dict[Modality, bool] = {},
) -> Optional[torch.Tensor]:
"""
Embed multimodal inputs and integrate them with text token embeddings.
......@@ -533,7 +532,9 @@ def embed_mm_inputs(
for mm_inputs in mm_inputs_list:
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
embeddings, masks, deepstack_embeddings = [], [], []
# deepstack_embeddings: per-modality
modalities, embeddings, masks, deepstack_embeddings = [], [], [], []
# 2. Get multimodal embedding separately
# Try get mm embedding if any
for modality in Modality.all():
......@@ -549,7 +550,8 @@ def embed_mm_inputs(
# "image", "video", etc
modality_id = modality.name.lower()
embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
if len(items) != 0 and embedder is not None:
if len(items) != 0:
assert embedder is not None, f"no embedding method found for {modality}"
placeholder_tensor = torch.as_tensor(
[item.pad_value for item in items],
device=input_ids.device,
......@@ -580,11 +582,12 @@ def embed_mm_inputs(
items_offset_list=items_offsets,
)
if use_deepstack and embedding is not None:
if use_deepstack.get(modality, None) and embedding is not None:
embedding, deepstack_embedding = (
multimodal_model.separate_deepstack_embeds(embedding)
)
deepstack_embeddings += [deepstack_embedding]
modalities += [modality]
embeddings += [embedding]
masks += [mask]
......@@ -597,17 +600,14 @@ def embed_mm_inputs(
input_ids.clamp_(min=0, max=vocab_size - 1)
inputs_embeds = input_embedding(input_ids)
# 4. scatter embeddings into input embedding
# deepstack embedding
if use_deepstack:
num_deepstack_embeddings = (
len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
)
num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes)
deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
inputs_embeds.shape[-1] * num_deepstack_embeddings,
)
# a zero-filled embedding, with the same length of inputs_embeds, but different hidden_size
input_deepstack_embeds = torch.zeros(
deepstack_embedding_shape,
device=inputs_embeds.device,
......@@ -616,14 +616,16 @@ def embed_mm_inputs(
other_info["input_deepstack_embeds"] = input_deepstack_embeds
for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
# 4. scatter embeddings into input embedding
for i, modality, embedding, mask in zip(
range(len(embeddings)), modalities, embeddings, masks
):
if embedding is None or mask is None:
continue
# in-place update
indices = torch.where(mask.squeeze(dim=-1))[0]
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
if use_deepstack:
if use_deepstack.get(modality, None):
input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
inputs_embeds.device, inputs_embeds.dtype
)
......@@ -640,7 +642,7 @@ def general_mm_embed_routine(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
use_deepstack: bool = False,
use_deepstack: Dict[Modality, bool] = {},
**kwargs,
) -> torch.Tensor:
"""
......@@ -652,7 +654,7 @@ def general_mm_embed_routine(
language_model: Base language model to use
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
placeholder_tokens: Token IDs for multimodal placeholders
use_deepstack: Whether to use deepstack embeddings
use_deepstack: Whether to use deepstack embeddings for each modality, default False
**kwargs: Additional arguments passed to language model
Returns:
......
......@@ -587,9 +587,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
)
if self.mm_processor and obj.contains_mm_input():
if not isinstance(obj.image_data, list) and obj.image_data:
if obj.image_data is not None and not isinstance(obj.image_data, list):
obj.image_data = [obj.image_data]
if not isinstance(obj.audio_data, list) and obj.audio_data:
if obj.audio_data is not None and not isinstance(obj.audio_data, list):
obj.audio_data = [obj.audio_data]
mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
image_data=obj.image_data,
......
......@@ -518,6 +518,7 @@ class Qwen2MoeModel(nn.Module):
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()
......
......@@ -661,13 +661,14 @@ class Qwen3MoeModel(Qwen2MoeModel):
config: Qwen3MoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
decoder_layer_type=Qwen3MoeDecoderLayer,
) -> None:
alt_stream = torch.cuda.Stream() if _is_cuda else None
super().__init__(
config=config,
quant_config=quant_config,
prefix=prefix,
decoder_layer_type=Qwen3MoeDecoderLayer,
decoder_layer_type=decoder_layer_type,
alt_stream=alt_stream,
)
......
# Copyright 2025 Qwen Team
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput
from sglang.srt.configs.qwen3_omni import (
Qwen3OmniMoeAudioEncoderConfig,
Qwen3OmniMoeThinkerConfig,
Qwen3OmniMoeVisionEncoderConfig,
)
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import MultimodalDataItem
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3_vl import Qwen3VLMoeVisionModel
from sglang.srt.models.qwen3_vl_moe import (
Qwen3MoeLLMModel,
Qwen3VLMoeForConditionalGeneration,
load_fused_expert_weights,
)
from sglang.srt.utils import add_prefix, logger
class Qwen3OmniMoeAudioEncoderLayer(nn.Module):
def __init__(
self,
config: Qwen3OmniMoeAudioEncoderConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
embed_dim = config.d_model
self.embed_dim = config.d_model
self.self_attn = VisionAttention(
embed_dim=embed_dim,
num_heads=config.encoder_attention_heads,
projection_size=embed_dim,
use_qkv_parallel=True,
rotary_embed="normal",
proj_bias=True,
qkv_backend="fa3",
softmax_in_single_precision=False,
flatten_batch=True,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
**kwargs,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
x=hidden_states,
cu_seqlens=cu_seqlens,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(
hidden_states, min=-clamp_value, max=clamp_value
)
outputs = (hidden_states,)
return outputs
class SinusoidsPositionEmbedding(nn.Module):
def __init__(self, length, channels, max_timescale=10000):
super().__init__()
if channels % 2 != 0:
raise ValueError("SinusoidsPositionEmbedding needs even channels input")
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(
-log_timescale_increment * torch.arange(channels // 2).float()
)
scaled_time = (
torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
)
self.register_buffer(
"positional_embedding",
torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
persistent=False,
)
def forward(self, seqlen: int):
return self.positional_embedding[:seqlen, :]
def _get_feat_extract_output_lengths(input_lengths):
"""
Computes the output length of the convolutional layers and the output length of the audio encoder
"""
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return output_lengths
class Qwen3OmniMoeAudioEncoder(PreTrainedModel):
config: Qwen3OmniMoeAudioEncoderConfig
def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig):
super().__init__(config)
self.dropout = config.dropout
embed_dim = config.d_model
self.num_mel_bins = config.num_mel_bins
self.max_source_positions = config.max_source_positions
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.n_window = config.n_window
self.positional_embedding = SinusoidsPositionEmbedding(
self.max_source_positions, embed_dim
)
self.layers = nn.ModuleList(
[
Qwen3OmniMoeAudioEncoderLayer(config)
for _ in range(config.encoder_layers)
]
)
self.ln_post = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
self.conv2d2 = nn.Conv2d(
config.downsample_hidden_size,
config.downsample_hidden_size,
3,
2,
padding=1,
)
self.conv2d3 = nn.Conv2d(
config.downsample_hidden_size,
config.downsample_hidden_size,
3,
2,
padding=1,
)
self.conv_out = nn.Linear(
config.downsample_hidden_size
* ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2),
config.d_model,
bias=False,
)
self.proj1 = nn.Linear(config.d_model, config.d_model)
self.act = ACT2FN[config.activation_function]
self.proj2 = nn.Linear(config.d_model, config.output_dim)
self.n_window_infer = self.config.n_window_infer
self.conv_chunksize = self.config.conv_chunksize
def _freeze_parameters(self):
for param in self.parameters():
param.requires_grad = False
self._requires_grad = False
def get_input_embeddings(self) -> nn.Module:
return self.conv1
def set_input_embeddings(self, value: nn.Module):
self.conv1 = value
def forward(
self,
input_features,
feature_lens=None,
aftercnn_lens=None,
):
r"""
feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
mel length
aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
mel length after cnn
"""
aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
chunk_lengths = torch.tensor(
[self.n_window * 2] * chunk_num.sum(),
dtype=torch.long,
device=feature_lens.device,
)
tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
chunk_lengths[chunk_lengths == 0] = self.n_window * 2
chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
padded_feature = nn.utils.rnn.pad_sequence(
chunk_list, batch_first=True
).transpose(1, 2)
feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
[
torch.ones(length, dtype=torch.bool, device=padded_feature.device)
for length in feature_lens_after_cnn
],
batch_first=True,
)
padded_feature = padded_feature.unsqueeze(1)
# Split to chunk to avoid OOM during convolution
padded_embeds = []
for chunk in padded_feature.split(self.conv_chunksize, dim=0):
padded_embed = F.gelu(self.conv2d1(chunk))
padded_embed = F.gelu(self.conv2d2(padded_embed))
padded_embed = F.gelu(self.conv2d3(padded_embed))
padded_embeds.append(padded_embed)
padded_embed = torch.cat(padded_embeds, dim=0)
b, c, f, t = padded_embed.size()
padded_embed = self.conv_out(
padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)
)
positional_embedding = (
self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
.unsqueeze(0)
.to(padded_embed.dtype)
)
padded_embed = padded_embed + positional_embedding
hidden_states = padded_embed[padded_mask_after_cnn]
cu_chunk_lens = [0]
window_aftercnn = padded_mask_after_cnn.shape[-1] * (
self.n_window_infer // (self.n_window * 2)
)
for cnn_len in aftercnn_lens:
cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
remainder = cnn_len % window_aftercnn
if remainder != 0:
cu_chunk_lens += [remainder]
cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(
-1, dtype=torch.int32
)
for encoder_layer in self.layers:
layer_outputs = encoder_layer(
hidden_states,
cu_seqlens,
)
hidden_states = layer_outputs[0]
hidden_states = self.ln_post(hidden_states)
hidden_states = self.proj1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.proj2(hidden_states)
return BaseModelOutput(last_hidden_state=hidden_states)
# Ignore copy
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""
Computes the output length of the convolutional layers and the output length of the audio encoder
"""
input_lengths = (input_lengths - 1) // 2 + 1
output_lengths = (input_lengths - 2) // 2 + 1
return input_lengths, output_lengths
class Qwen3OmniMoeVisionPatchMerger(nn.Module):
def __init__(
self,
dim: int,
context_dim: int,
spatial_merge_size: int = 2,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_postshuffle_norm=False,
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
self.use_postshuffle_norm = use_postshuffle_norm
self.ln_q = RMSNorm(
self.hidden_size if use_postshuffle_norm else context_dim, eps=1e-6
)
self.mlp = nn.ModuleList(
[
ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=add_prefix("mlp.0", prefix),
),
nn.GELU(),
RowParallelLinear(
self.hidden_size,
dim,
bias=True,
quant_config=quant_config,
prefix=add_prefix("mlp.2", prefix),
),
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = (
x.view(-1, self.hidden_size)
if self.use_postshuffle_norm
else x.view(-1, x.shape[-1])
)
hidden = self.ln_q(x).view(-1, self.hidden_size)
for layer in self.mlp:
if isinstance(hidden, tuple):
hidden = hidden[0]
hidden = layer(hidden)
if isinstance(hidden, tuple):
hidden = hidden[0]
return hidden
class Qwen3OmniMoeVisionEncoder(Qwen3VLMoeVisionModel):
config: Qwen3OmniMoeVisionEncoderConfig
def __init__(
self,
config: Qwen3OmniMoeVisionEncoderConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = None,
**kwargs,
):
super().__init__(
vision_config=config,
quant_config=quant_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
)
self.merger = Qwen3OmniMoeVisionPatchMerger(
dim=config.out_hidden_size,
context_dim=config.hidden_size,
spatial_merge_size=config.spatial_merge_size,
quant_config=quant_config,
use_postshuffle_norm=False,
prefix=add_prefix("merger", prefix),
)
self.merger_list = nn.ModuleList(
[
Qwen3OmniMoeVisionPatchMerger(
dim=config.out_hidden_size,
context_dim=config.hidden_size,
spatial_merge_size=config.spatial_merge_size,
use_postshuffle_norm=True,
quant_config=quant_config,
prefix=add_prefix("merger_list", prefix),
)
for _ in range(len(config.deepstack_visual_indexes))
]
)
del self.deepstack_merger_list
@property
def deepstack_merger_list(self):
return self.merger_list
@property
def dtype(self) -> torch.dtype:
return self.patch_embed.proj.weight.dtype
@property
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device
class Qwen3OmniMoeThinkerForConditionalGeneration(Qwen3VLMoeForConditionalGeneration):
config: Qwen3OmniMoeThinkerConfig
def __init__(
self,
config: Qwen3OmniMoeThinkerConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(
config, quant_config, prefix, language_model_cls=Qwen3MoeLLMModel
)
self.audio_tower = Qwen3OmniMoeAudioEncoder(config.audio_config)
self.visual = Qwen3OmniMoeVisionEncoder(
config.vision_config,
quant_config=quant_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
prefix=add_prefix("visual", prefix),
)
self.pad_token_id = (
self.config.pad_token_id if self.config.pad_token_id is not None else -1
)
def get_audio_feature(self, items: List[MultimodalDataItem]):
feature_attention_mask = torch.cat(
[item.feature_attention_mask for item in items], dim=0
).type(torch.long)
input_features = (
torch.cat([item.feature for item in items])
.type(self.audio_tower.dtype)
.to(next(self.audio_tower.parameters()).device)
)
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
input_features = input_features.permute(0, 2, 1)[
feature_attention_mask.bool()
].permute(1, 0)
else:
audio_feature_lengths = None
feature_lens = (
audio_feature_lengths
if audio_feature_lengths is not None
else feature_attention_mask.sum(-1)
)
audio_outputs = self.audio_tower(
input_features,
feature_lens=feature_lens,
)
audio_features = audio_outputs.last_hidden_state
return audio_features
class Qwen3OmniMoeForConditionalGeneration(PreTrainedModel):
def __init__(
self,
config: Qwen3VLMoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(config)
self.config = config
self.thinker = Qwen3OmniMoeThinkerForConditionalGeneration(
config.thinker_config, quant_config=quant_config, prefix=prefix
)
self.enable_talker = False
self.pad_input_ids = self.thinker.pad_input_ids
self.forward = self.thinker.forward
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
("gate_up_proj", "up_proj", 1),
("gate_up_proj", "gate_proj", 0),
]
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
# Skip loading extra parameters for GPTQ/modelopt models.
ignore_suffixes = (
".bias",
"_bias",
".k_scale",
"_k_scale",
".v_scale",
"_v_scale",
".weight_scale",
"_weight_scale",
".input_scale",
"_input_scale",
)
is_fused_expert = False
fused_expert_params_mapping = [
("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
("experts.w2_weight", "experts.down_proj", 0, "w2"),
]
num_experts = self.config.num_experts
# Cache params_dict to avoid repeated expensive traversal of model parameters
if not hasattr(self, "_cached_params_dict"):
self._cached_params_dict = dict(self.named_parameters())
params_dict = self._cached_params_dict
for name, loaded_weight in weights:
name = name.replace(r"model.language_model.", r"model.")
if ("talker" in name or "code2wav" in name) and not self.enable_talker:
continue
name = name.replace(".self_attn.out_proj", ".self_attn.proj")
for param_name, weight_name, shard_id in stacked_params_mapping:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
is_fused_expert = True
expert_params_mapping = fused_expert_params_mapping
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
if "visual" in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(ignore_suffixes) and name not in params_dict:
continue
# [TODO] Skip layers that are on other devices (check if sglang has a similar function)
# if is_pp_missing_parameter(name, self):
# continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Track if this is an expert weight to enable early skipping
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
if "visual" in name or "audio_tower" in name:
continue
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
name_mapped = name.replace(weight_name, param_name)
if is_fused_expert:
loaded_weight = loaded_weight.transpose(-1, -2) # no bias
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-2)
load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[0],
"w1",
num_experts,
)
load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[1],
"w3",
num_experts,
)
else:
load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight,
shard_id,
num_experts,
)
else:
# Skip loading extra parameters for GPTQ/modelopt models.
if (
name_mapped.endswith(ignore_suffixes)
and name_mapped not in params_dict
):
continue
param = params_dict[name_mapped]
# We should ask the weight loader to return success or
# not here since otherwise we may skip experts with
# # other available replicas.
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
)
name = name_mapped
break
else:
if is_expert_weight:
# This is an expert weight but not mapped to this rank, skip all remaining processing
continue
if "visual" in name or "audio_tower" in name:
# adapt to VisionAttention
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
name = name.replace(r"model.visual.", r"visual.")
name = name.replace(r"attn.out_proj.", r"attn.proj.")
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(ignore_suffixes) and name not in params_dict:
continue
if name in params_dict.keys():
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
else:
logger.warning(
f"Loaded weight with {name=} not found in params_dict"
)
EntryClass = Qwen3OmniMoeForConditionalGeneration
......@@ -15,7 +15,7 @@
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import Callable, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
......@@ -27,7 +27,11 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionRotaryEmbedding,
)
from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
from sglang.srt.configs.qwen3_vl import (
Qwen3VLConfig,
Qwen3VLTextConfig,
Qwen3VLVisionConfig,
)
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
......@@ -38,16 +42,24 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
from sglang.srt.models.qwen3 import Qwen3Model
from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor
logger = logging.getLogger(__name__)
# === Vision Encoder === #
......@@ -196,7 +208,7 @@ class Qwen3_VisionBlock(nn.Module):
return x
class Qwen3_VisionPatchMerger(nn.Module):
class Qwen3VLMoeVisionPatchMerger(nn.Module):
def __init__(
self,
......@@ -246,7 +258,7 @@ class Qwen3_VisionPatchMerger(nn.Module):
return out
class Qwen3_VisionTransformer(nn.Module):
class Qwen3VLMoeVisionModel(nn.Module):
def __init__(
self,
......@@ -263,10 +275,10 @@ class Qwen3_VisionTransformer(nn.Module):
self.spatial_merge_size = vision_config.spatial_merge_size
self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size
# layer indexes of which layer's output should be deep-stacked
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
......@@ -286,7 +298,7 @@ class Qwen3_VisionTransformer(nn.Module):
for layer_idx in range(vision_config.depth)
]
)
self.merger = Qwen3_VisionPatchMerger(
self.merger = Qwen3VLMoeVisionPatchMerger(
dim=vision_config.out_hidden_size,
context_dim=self.hidden_size,
norm_layer=norm_layer,
......@@ -297,7 +309,7 @@ class Qwen3_VisionTransformer(nn.Module):
self.deepstack_merger_list = nn.ModuleList(
[
Qwen3_VisionPatchMerger(
Qwen3VLMoeVisionPatchMerger(
dim=vision_config.out_hidden_size,
context_dim=self.hidden_size,
spatial_merge_size=self.spatial_merge_size,
......@@ -462,7 +474,6 @@ class Qwen3_VisionTransformer(nn.Module):
]
)
# max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
x = x.unsqueeze(1)
deepstack_feature_lists = []
......@@ -604,37 +615,43 @@ class Qwen3VLForConditionalGeneration(nn.Module):
config: Qwen3VLConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
language_model_cls=Qwen3LLMModel,
) -> None:
super().__init__()
self.config = config
self.visual = Qwen3_VisionTransformer(
self.visual = Qwen3VLMoeVisionModel(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
quant_config=quant_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
prefix=add_prefix("visual", prefix),
)
self.model = Qwen3LLMModel(
config=config,
# TODO: make it more elegant
if language_model_cls is Qwen3LLMModel:
self.config: Qwen3VLConfig = config # for qwen3-vl
else:
self.config = config.text_config # for qwen3-omni
self.model = language_model_cls(
config=self.config,
quant_config=quant_config,
prefix=add_prefix("model", prefix),
)
if config.tie_word_embeddings:
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
self.config.vocab_size,
self.config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config)
self.logits_processor = LogitsProcessor(self.config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on
# 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states
......@@ -642,10 +659,7 @@ class Qwen3VLForConditionalGeneration(nn.Module):
# deepstack
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
@property
def use_deepstack(self) -> bool:
return hasattr(self, "deepstack_visual_indexes")
self.use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True}
def separate_deepstack_embeds(self, embedding):
assert (
......
......@@ -14,29 +14,19 @@
# ==============================================================================
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from functools import lru_cache
from typing import Iterable, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature
from transformers.activations import ACT2FN
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionRotaryEmbedding,
)
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeVisionConfig
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
get_tensor_model_parallel_rank,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import general_mm_embed_routine
......@@ -44,11 +34,7 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_vl import (
Qwen3_VisionTransformer,
Qwen3VLForConditionalGeneration,
)
from sglang.srt.utils import add_prefix
from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
from sglang.srt.utils.hf_transformers_utils import get_processor
logger = logging.getLogger(__name__)
......@@ -60,28 +46,16 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
def __init__(
self,
*,
config: Qwen3VLMoeConfig,
config: Qwen3VLMoeTextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(config=config, quant_config=quant_config, prefix=prefix)
self.hidden_size = config.hidden_size
def get_input_embeddings(self) -> nn.Embedding:
return self.embed_tokens
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
# in qwen-vl, last dim is the same
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
self.visual.dtype
)
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds
def forward(
self,
input_ids: torch.Tensor,
......@@ -120,7 +94,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
)
# process deepstack
if input_deepstack_embeds is not None and layer_idx in range(3):
if input_deepstack_embeds is not None and layer_idx < 3:
sep = self.hidden_size * layer_idx
hidden_states.add_(
input_deepstack_embeds[:, sep : sep + self.hidden_size]
......@@ -146,144 +120,56 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
return hidden_states, aux_hidden_states
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
def __init__(
self,
*,
config: Qwen3VLMoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super(Qwen3VLForConditionalGeneration, self).__init__()
self.config = config
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
quant_config=quant_config,
prefix=add_prefix("visual", prefix),
)
self.model = Qwen3MoeLLMModel(
config=config,
quant_config=quant_config,
prefix=add_prefix("model", prefix),
)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
def load_fused_expert_weights(
name: str,
params_dict: dict,
loaded_weight: torch.Tensor,
shard_id: str,
num_experts: int,
):
param = params_dict[name]
# weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
weight_loader = param.weight_loader
ep_rank = get_tensor_model_parallel_rank()
ep_size = get_moe_expert_parallel_world_size()
if ep_size == 1:
for expert_id in range(num_experts):
curr_expert_weight = loaded_weight[expert_id]
weight_loader(
param,
curr_expert_weight,
name,
shard_id,
expert_id,
)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# deepstack
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
@property
def use_deepstack(self) -> bool:
return hasattr(self, "deepstack_visual_indexes")
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
get_embedding: bool = False,
):
"""Run forward pass for Qwen3-VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,).
(Use input_metadata.mrope_positions to replace it)
"""
if self.is_mrope_enabled:
positions = forward_batch.mrope_positions
if not (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
if self.is_mrope_enabled:
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.model,
multimodal_model=self,
positions=positions,
use_deepstack=self.use_deepstack,
else:
experts_per_ep = num_experts // ep_size
start_expert = ep_rank * experts_per_ep
end_expert = (
(ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
for idx, expert_id in enumerate(range(start_expert, end_expert)):
curr_expert_weight = loaded_weight[expert_id]
weight_loader(
param,
curr_expert_weight,
name,
shard_id,
idx,
)
else:
return self.pooler(hidden_states, forward_batch)
return True
def load_fused_expert_weights(
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
def __init__(
self,
name: str,
params_dict: dict,
loaded_weight: torch.Tensor,
shard_id: str,
num_experts: int,
config: Qwen3VLMoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
language_model_cls=Qwen3MoeLLMModel,
):
param = params_dict[name]
# weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
weight_loader = param.weight_loader
ep_rank = get_tensor_model_parallel_rank()
ep_size = get_moe_expert_parallel_world_size()
if ep_size == 1:
for expert_id in range(num_experts):
curr_expert_weight = loaded_weight[expert_id]
weight_loader(
param,
curr_expert_weight,
name,
shard_id,
expert_id,
)
else:
experts_per_ep = num_experts // ep_size
start_expert = ep_rank * experts_per_ep
end_expert = (
(ep_rank + 1) * experts_per_ep
if ep_rank != ep_size - 1
else num_experts
)
for idx, expert_id in enumerate(range(start_expert, end_expert)):
curr_expert_weight = loaded_weight[expert_id]
weight_loader(
param,
curr_expert_weight,
name,
shard_id,
idx,
)
return True
super().__init__(config, quant_config, prefix, language_model_cls)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......@@ -329,8 +215,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
self._cached_params_dict = dict(self.named_parameters())
params_dict = self._cached_params_dict
for name, loaded_weight in weights:
if "language_model" in name:
name = name.replace(r"model.language_model.", r"model.")
name = name.replace(r"model.language_model.", r"model.")
for param_name, weight_name, shard_id in stacked_params_mapping:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
......@@ -384,14 +269,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
loaded_weight = loaded_weight.transpose(-1, -2) # no bias
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-2)
self.load_fused_expert_weights(
load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[0],
"w1",
num_experts,
)
self.load_fused_expert_weights(
load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[1],
......@@ -399,7 +284,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
num_experts,
)
else:
self.load_fused_expert_weights(
load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight,
......
......@@ -155,7 +155,6 @@ class BaseMultimodalProcessor(ABC):
):
self.hf_config = hf_config
self._processor = _processor
self.arch = hf_config.architectures[0]
self.server_args = server_args
self.transport_mode = transport_mode
......@@ -191,6 +190,7 @@ class BaseMultimodalProcessor(ABC):
"input_features": Modality.AUDIO,
"input_features_mask": Modality.AUDIO,
"audio_attention_mask": Modality.AUDIO,
"feature_attention_mask": Modality.AUDIO,
# Video-related attributes
"pixel_values_videos": Modality.VIDEO,
"second_per_grid_ts": Modality.VIDEO,
......@@ -222,6 +222,7 @@ class BaseMultimodalProcessor(ABC):
if self._processor.__class__.__name__ in {
"Gemma3nProcessor",
"Qwen2AudioProcessor",
"Qwen3OmniMoeProcessor",
}:
# Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
kwargs["audio"] = audios
......
......@@ -12,6 +12,7 @@ from torchvision.transforms import InterpolationMode
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
from sglang.srt.models.qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration
from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (
......@@ -209,22 +210,31 @@ async def preprocess_video(
return video
# Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
# Compatible with Qwen-VL & Qwen-Omni Series
class QwenVLImageProcessor(SGLangBaseProcessor):
models = [
Qwen2VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
Qwen3VLForConditionalGeneration,
Qwen3VLMoeForConditionalGeneration,
Qwen3OmniMoeForConditionalGeneration,
]
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
self.model_type = hf_config.model_type
if hf_config.model_type == "qwen3_omni_moe":
hf_config = hf_config.thinker_config
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
# The regex that matches expanded image tokens.
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
self.vision_start_token_id = hf_config.vision_start_token_id
self.vision_end_token_id = hf_config.vision_end_token_id
self.audio_start_token_id = getattr(hf_config, "audio_start_token_id", None)
self.audio_token_id = getattr(hf_config, "audio_token_id", None)
self.NUM_TOKEN_PER_FRAME = 770
self.IMAGE_FACTOR = 28
self.MIN_PIXELS = 4 * 28 * 28
......@@ -233,10 +243,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self.mm_tokens = MultimodalSpecialTokens(
image_token="<|vision_start|><|image_pad|><|vision_end|>",
image_token_id=hf_config.image_token_id,
# The regex that matches expanded image tokens.
image_token_regex=re.compile(
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
),
video_token_id=hf_config.video_token_id,
audio_token_id=self.audio_token_id,
).build(_processor)
async def process_mm_data_async(
......@@ -247,11 +259,11 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
*args,
**kwargs,
):
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
video_data=request_obj.video_data,
audio_data=request_obj.audio_data,
multimodal_tokens=self.mm_tokens,
)
......@@ -269,20 +281,41 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
base_output, self.mm_tokens
)
audio_feature_lengths = None
if self.model_type == "qwen3_omni_moe":
audio_item = next((mm for mm in mm_items if mm.is_audio()), None)
if audio_item:
audio_feature_lengths = torch.sum(
audio_item.feature_attention_mask, dim=1
)
second_per_grid_ts = getattr(ret, "second_per_grid_ts", None) or getattr(
ret, "video_second_per_grid", None
)
input_ids = input_ids.flatten()
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
image_token_id=self.mm_tokens.image_token_id,
video_token_id=self.mm_tokens.video_token_id,
vision_start_token_id=self.vision_start_token_id,
model_type=self.hf_config.model_type,
model_type=self.model_type,
tokens_per_second=getattr(
self.hf_config.vision_config, "tokens_per_second", None
),
input_ids=input_ids.unsqueeze(0),
image_grid_thw=getattr(ret, "image_grid_thw", None),
video_grid_thw=getattr(ret, "video_grid_thw", None),
second_per_grid_ts=getattr(ret, "second_per_grid_ts", None),
second_per_grid_ts=second_per_grid_ts,
use_audio_in_video=False,
audio_seqlens=audio_feature_lengths,
audio_token_id=getattr(self.hf_config, "audio_token_id", None),
audio_start_token_id=self.audio_start_token_id,
position_id_per_seconds=getattr(
self.hf_config, "position_id_per_seconds", None
),
)
mrope_positions = mrope_positions.squeeze(1)
......@@ -293,6 +326,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
"im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.mm_tokens.image_token_id,
"video_token_id": self.mm_tokens.video_token_id,
"audio_token_id": self.mm_tokens.audio_token_id,
"mrope_positions": mrope_positions,
"mrope_position_delta": mrope_position_delta,
}
......@@ -355,9 +355,10 @@ class TestPhi4MMServer(ImageOpenAITestMixin, AudioOpenAITestMixin):
if __name__ == "__main__":
del (
TestOpenAIOmniServerBase,
TestOpenAIMLLMServerBase,
ImageOpenAITestMixin,
VideoOpenAITestMixin,
AudioOpenAITestMixin,
OmniOpenAITestMixin,
)
unittest.main()
......@@ -241,11 +241,35 @@ class TestGLM41VServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
cls.base_url += "/v1"
class TestQwen3OmniServer(OmniOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[ # workaround to fit into H100
"--trust-remote-code",
"--mem-fraction-static",
"0.90",
"--disable-cuda-graph",
"--disable-fast-image-processor",
"--grammar-backend",
"none",
],
)
cls.base_url += "/v1"
if __name__ == "__main__":
del (
TestOpenAIOmniServerBase,
TestOpenAIMLLMServerBase,
ImageOpenAITestMixin,
VideoOpenAITestMixin,
AudioOpenAITestMixin,
OmniOpenAITestMixin,
)
unittest.main()
import base64
import io
import os
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import openai
......@@ -22,7 +23,7 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
class TestOpenAIOmniServerBase(CustomTestCase):
class TestOpenAIMLLMServerBase(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = ""
......@@ -58,7 +59,20 @@ class TestOpenAIOmniServerBase(CustomTestCase):
return file_path
class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
class AudioOpenAITestMixin(TestOpenAIMLLMServerBase):
def verify_speech_recognition_response(self, text):
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in text.lower()
), f"audio_response: |{text}| should contain |{check_word}|"
def prepare_audio_messages(self, prompt, audio_file_name):
messages = [
{
......@@ -116,17 +130,7 @@ class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
"Listen to this audio and write down the audio transcription in English.",
category="speech",
)
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in audio_response
), f"audio_response: |{audio_response}| should contain |{check_word}|"
self.verify_speech_recognition_response(audio_response)
def test_audio_ambient_completion(self):
# bird song
......@@ -138,26 +142,39 @@ class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
assert "bird" in audio_response
class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
def test_single_image_chat_completion(self):
class ImageOpenAITestMixin(TestOpenAIMLLMServerBase):
def run_decode_with_image(self, image_id):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
content = []
if image_id == 0:
content.append(
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
}
)
elif image_id == 1:
content.append(
{
"type": "image_url",
"image_url": {"url": IMAGE_SGL_LOGO_URL},
}
)
else:
pass
content.append(
{
"type": "text",
"text": "Describe this image in a sentence.",
}
)
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "text",
"text": "Describe this image in a sentence.",
},
],
},
{"role": "user", "content": content},
],
temperature=0,
**(self.get_vision_request_kwargs()),
......@@ -166,6 +183,17 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
def test_mixed_batch(self):
image_ids = [0, 1, 2] * 4
with ThreadPoolExecutor(4) as executor:
list(executor.map(self.run_decode_with_image, image_ids))
def verify_single_image_response(self, response):
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
# `driver` is for gemma-3-it
assert (
"man" in text or "person" or "driver" in text
......@@ -179,19 +207,44 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
), f"text: {text}, should contain cab, taxi, SUV, vehicle or car"
# MiniCPMO fails to recognize `iron`, but `hanging`
assert (
"iron" in text
or "hang" in text
or "cloth" in text
or "coat" in text
or "holding" in text
or "outfit" in text
), f"text: {text}, should contain iron, hang, cloth, coat or holding or outfit"
"iron" in text or "hang" in text or "cloth" in text or "holding" in text
), f"text: {text}, should contain iron, hang, cloth or holding"
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def test_single_image_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "text",
"text": "Describe this image in a sentence.",
},
],
},
],
temperature=0,
**(self.get_vision_request_kwargs()),
)
print("-" * 30)
print(f"Single image response:\n{response.choices[0].message.content}")
print("-" * 30)
self.verify_single_image_response(response)
def test_multi_turn_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
......@@ -264,8 +317,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
},
{
"type": "text",
"text": "I have two very different images. They are not related at all. "
"Please describe the first image in one sentence, and then describe the second image in another sentence.",
"text": "I have two very different images. Please describe them.",
},
],
},
......@@ -296,64 +348,6 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def _test_mixed_image_audio_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "audio_url",
"audio_url": {"url": AUDIO_TRUMP_SPEECH_URL},
},
{
"type": "text",
"text": "Please describe the image in one sentence, and then write down the audio transcription in English.",
},
],
},
],
temperature=0,
**(self.get_vision_request_kwargs()),
)
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
print("-" * 30)
print(f"Mixed image & audio response:\n{text}")
print("-" * 30)
assert (
"man" in text
or "cab" in text
or "SUV" in text
or "taxi" in text
or "car" in text
), f"text: {text}, should contain man, cab, SUV, taxi or car"
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in text
), f"text: |{text}| should contain |{check_word}|"
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def prepare_video_images_messages(self, video_path):
# the memory consumed by the Vision Attention varies a lot, e.g. blocked qkv vs full-sequence sdpa
# the size of the video embeds differs from the `modality` argument when preprocessed
......@@ -461,7 +455,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
self.assertGreater(len(video_response), 0)
class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
class VideoOpenAITestMixin(TestOpenAIMLLMServerBase):
def prepare_video_messages(self, video_path):
messages = [
{
......@@ -526,3 +520,45 @@ class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
), f"video_response: {video_response}, should contain 'black' or 'dark'"
self.assertIsNotNone(video_response)
self.assertGreater(len(video_response), 0)
class OmniOpenAITestMixin(
ImageOpenAITestMixin, VideoOpenAITestMixin, AudioOpenAITestMixin
):
def test_mixed_modality_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "audio_url",
"audio_url": {"url": AUDIO_TRUMP_SPEECH_URL},
},
{
"type": "text",
"text": "I have an image and audio, which are not related at all. Please: 1. Describe the image in a sentence, 2. Repeat the exact words from the audio I provided. Be exact",
},
],
},
]
response = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=128,
stream=False,
)
text = response.choices[0].message.content
print("-" * 30)
print(f"Mixed modality response:\n{text}")
print("-" * 30)
self.verify_single_image_response(response=response)
self.verify_speech_recognition_response(text=text)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment