Unverified Commit 86b04d25, authored by Mick, committed by GitHub

model: qwen3-omni (thinker-only) (#10911)


Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
parent 85ebeecf
@@ -853,6 +853,7 @@ multimodal_model_archs = [
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
+"Qwen3OmniMoeForConditionalGeneration",
"KimiVLForConditionalGeneration",
"InternVLChatModel",
"InternS1ForConditionalGeneration",
...
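Note: serving code typically decides whether a checkpoint is multimodal by matching the `architectures` entry of its Hugging Face config.json against a registry like the one above. A minimal, hypothetical sketch of that lookup (not SGLang's actual helper):

import json

# Registry excerpt copied from the diff above; the helper itself is hypothetical.
multimodal_model_archs = [
    "Qwen3VLMoeForConditionalGeneration",
    "Qwen3OmniMoeForConditionalGeneration",  # registered by this commit
]

def is_multimodal_checkpoint(config_path: str) -> bool:
    # HF checkpoints name their model class in the "architectures" field of config.json.
    with open(config_path) as f:
        cfg = json.load(f)
    return any(arch in multimodal_model_archs for arch in cfg.get("architectures", []))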
This diff is collapsed.
from typing import Optional, Union
from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
@@ -576,11 +574,3 @@ class Qwen3VLMoeConfig(PretrainedConfig):
self.vision_start_token_id = vision_start_token_id
self.vision_end_token_id = vision_end_token_id
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
-__all__ = [
-"Qwen3VLMoeConfig",
-"Qwen3VLMoeVisionConfig",
-"Qwen3VLConfig",
-"Qwen3VLVisionConfig",
-]
@@ -1156,6 +1156,20 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
if model_type == "qwen3_omni_moe":
# For qwen3-omni
return MRotaryEmbedding.get_rope_index_qwen3_omni(
spatial_merge_size,
image_token_id,
video_token_id,
vision_start_token_id,
tokens_per_second,
input_ids,
image_grid_thw,
video_grid_thw,
second_per_grid_ts,
**kwargs,
)
if (
model_type.startswith("qwen3_vl") or model_type.startswith("qwen3_vl_moe")
) and video_grid_thw is not None:
@@ -1163,6 +1177,7 @@ class MRotaryEmbedding(RotaryEmbedding):
video_grid_thw, video_grid_thw[:, 0], dim=0
)
video_grid_thw[:, 0] = 1
mrope_position_deltas = []
if input_ids is not None and (
image_grid_thw is not None or video_grid_thw is not None
@@ -1248,7 +1263,11 @@ class MRotaryEmbedding(RotaryEmbedding):
time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten()
-elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
+elif model_type in (
+"qwen2_vl",
+"qwen3_vl",
+"qwen3_vl_moe",
+):
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
@@ -1256,7 +1275,7 @@ class MRotaryEmbedding(RotaryEmbedding):
.flatten()
)
else:
-raise RuntimeError("Unimplemented")
+raise RuntimeError(f"Unimplemented model type: {model_type}")
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
@@ -1306,6 +1325,304 @@ class MRotaryEmbedding(RotaryEmbedding):
mrope_position_deltas = max_position_ids + 1 - s
return position_ids, mrope_position_deltas
@staticmethod
def get_rope_index_qwen3_omni(
spatial_merge_size: int,
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
tokens_per_second: Optional[int] = None,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# For qwen3-omni
audio_token_id = kwargs["audio_token_id"]
audio_start_token_id = kwargs["audio_start_token_id"]
position_id_per_seconds = kwargs["position_id_per_seconds"]
use_audio_in_video = kwargs.get("use_audio_in_video", False)
audio_seqlens = kwargs.get("audio_seqlens", None)
second_per_grids = second_per_grid_ts
mrope_position_deltas = []
if input_ids is not None and (
image_grid_thw is not None or video_grid_thw is not None
):
total_input_ids = input_ids
position_ids = torch.zeros(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=torch.float,
device=input_ids.device,
)
image_idx, video_idx, audio_idx = 0, 0, 0
for i, current_input_ids in enumerate(total_input_ids):
image_nums, video_nums, audio_nums = 0, 0, 0
vision_start_indices = torch.argwhere(
current_input_ids == vision_start_token_id
).squeeze(1)
if vision_start_indices.numel() > 0:
vision_tokens = current_input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (
(vision_tokens == audio_start_token_id).sum()
if use_audio_in_video
else (vision_tokens == video_token_id).sum()
)
audio_nums = torch.sum(current_input_ids == audio_start_token_id)
input_tokens = current_input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos, remain_audios = (
image_nums,
video_nums,
audio_nums,
)
multimodal_nums = (
image_nums + audio_nums
if use_audio_in_video
else image_nums + video_nums + audio_nums
)
for _ in range(multimodal_nums):
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
ed_vision_start = (
input_tokens.index(vision_start_token_id, st)
if (
(
image_token_id in input_tokens
or video_token_id in input_tokens
)
and (remain_videos > 0 or remain_images > 0)
)
else len(input_tokens) + 1
)
ed_audio_start = (
input_tokens.index(audio_start_token_id, st)
if (audio_token_id in input_tokens and remain_audios > 0)
else len(input_tokens) + 1
)
min_ed = min(ed_vision_start, ed_audio_start)
text_len = min_ed - st
if text_len != 0:
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
st_idx += text_len
# Audio in Video
if (
min_ed == ed_vision_start
and ed_vision_start + 1 == ed_audio_start
):
bos_len, eos_len = 2, 2
else:
bos_len, eos_len = 1, 1
llm_pos_ids_list.append(
torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx
)
st_idx += bos_len
# Audio Only
if min_ed == ed_audio_start:
audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
audio_seqlens[audio_idx]
)
llm_pos_ids = (
torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
)
llm_pos_ids_list.append(llm_pos_ids)
st += int(text_len + bos_len + audio_len + eos_len)
audio_idx += 1
remain_audios -= 1
# Image Only
elif (
min_ed == ed_vision_start
and current_input_ids[ed_vision_start + 1] == image_token_id
):
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = (
torch.arange(grid_t) * 1 * position_id_per_seconds
).float()
llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
st_idx,
image_idx,
spatial_merge_size,
t_index,
grid_hs,
grid_ws,
input_ids.device,
)
image_len = image_grid_thw[image_idx].prod() // (
spatial_merge_size**2
)
llm_pos_ids_list.append(llm_pos_ids)
st += int(text_len + bos_len + image_len + eos_len)
image_idx += 1
remain_images -= 1
# Video Only
elif (
min_ed == ed_vision_start
and current_input_ids[ed_vision_start + 1] == video_token_id
):
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* second_per_grids[video_idx].cpu().float()
* position_id_per_seconds
).float()
llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
st_idx,
video_idx,
spatial_merge_size,
t_index,
grid_hs,
grid_ws,
input_ids.device,
)
video_len = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2
)
llm_pos_ids_list.append(llm_pos_ids)
st += int(text_len + bos_len + video_len + eos_len)
video_idx += 1
remain_videos -= 1
# Audio in Video
elif (
min_ed == ed_vision_start
and ed_vision_start + 1 == ed_audio_start
):
audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
audio_seqlens[audio_idx]
)
audio_llm_pos_ids = (
torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
)
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* second_per_grids[video_idx].cpu().float()
* position_id_per_seconds
).float()
video_llm_pos_ids = (
MRotaryEmbedding._get_llm_pos_ids_for_vision(
st_idx,
video_idx,
spatial_merge_size,
t_index,
grid_hs,
grid_ws,
input_ids.device,
)
)
video_data_index, audio_data_index = 0, 0
while (
video_data_index < video_llm_pos_ids.shape[-1]
and audio_data_index < audio_llm_pos_ids.shape[-1]
):
if (
video_llm_pos_ids[0][video_data_index]
<= audio_llm_pos_ids[0][audio_data_index]
):
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_data_index + 1
]
)
video_data_index += 1
else:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_data_index + 1
]
)
audio_data_index += 1
if video_data_index < video_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_llm_pos_ids.shape[-1]
]
)
if audio_data_index < audio_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_llm_pos_ids.shape[-1]
]
)
video_len = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2
)
st += int(text_len + bos_len + audio_len + video_len + eos_len)
audio_idx += 1
video_idx += 1
remain_videos -= 1
remain_audios -= 1
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
llm_pos_ids_list.append(
torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx
)
if st < len(input_tokens):
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
llm_positions = torch.cat(
[item.float() for item in llm_pos_ids_list], dim=1
).reshape(3, -1)
position_ids[..., i, :] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(
llm_positions.max() + 1 - len(current_input_ids)
)
mrope_position_deltas = torch.tensor(
mrope_position_deltas, device=input_ids.device
).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
s = input_ids.shape[1]
position_ids = torch.arange(s)
position_ids = (
position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
-1, keepdim=True
)[0]
mrope_position_deltas = max_position_ids + 1 - s
return position_ids, mrope_position_deltas
# Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
@staticmethod
def get_rope_index_glm4v(
@@ -1504,6 +1821,44 @@ class MRotaryEmbedding(RotaryEmbedding):
return position_ids, mrope_position_deltas
# For qwen3-omni
@staticmethod
def _get_feat_extract_output_lengths(input_lengths):
"""
Computes the output length of the convolutional layers and the output length of the audio encoder
"""
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return output_lengths
# For qwen3-omni
@staticmethod
def _get_llm_pos_ids_for_vision(
st_idx, vision_idx, spatial_merge_size, t_index, grid_hs, grid_ws, device
):
grid_h = grid_hs[vision_idx] // spatial_merge_size
grid_w = grid_ws[vision_idx] // spatial_merge_size
h_index = (
torch.arange(grid_h, device=device)
.view(1, -1, 1)
.expand(len(t_index), -1, grid_w)
.flatten()
)
w_index = (
torch.arange(grid_w, device=device)
.view(1, 1, -1)
.expand(len(t_index), grid_h, -1)
.flatten()
)
t_index = t_index.view(-1, 1).expand(-1, grid_h * grid_w).flatten()
llm_pos_ids = torch.stack([t_index, h_index, w_index], dim=0) + st_idx
return llm_pos_ids
class DualChunkRotaryEmbedding(CustomOp):
"""Rotary positional embedding for Dual Chunk Attention."""
...
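The _get_feat_extract_output_lengths helper above converts raw audio feature lengths into the number of positions the audio tower emits (a conv-style downsample plus 13 positions per full 100-frame chunk). A self-contained check of the arithmetic on plain Python ints, for illustration only:

def feat_extract_output_length(input_length: int) -> int:
    # Same arithmetic as _get_feat_extract_output_lengths, on a plain int
    # instead of a torch tensor.
    leave = input_length % 100
    feat = (leave - 1) // 2 + 1
    return ((feat - 1) // 2 + 1 - 1) // 2 + 1 + (input_length // 100) * 13

# e.g. 160 input frames -> 21 output positions, 300 -> 39
assert feat_extract_output_length(160) == 21
assert feat_extract_output_length(300) == 39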
@@ -280,7 +280,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
input_ids_tensor[input_ids_tensor == token_id] = pad_value
ret_input_ids = input_ids_tensor.tolist()
return ret_input_ids
@@ -507,7 +506,7 @@ def embed_mm_inputs(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
-use_deepstack: bool = False,
+use_deepstack: Dict[Modality, bool] = {},
) -> Optional[torch.Tensor]:
"""
Embed multimodal inputs and integrate them with text token embeddings.
@@ -533,7 +532,9 @@ def embed_mm_inputs(
for mm_inputs in mm_inputs_list:
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
-embeddings, masks, deepstack_embeddings = [], [], []
+# deepstack_embeddings: per-modality
+modalities, embeddings, masks, deepstack_embeddings = [], [], [], []
# 2. Get multimodal embedding separately
# Try get mm embedding if any
for modality in Modality.all():
@@ -549,7 +550,8 @@ def embed_mm_inputs(
# "image", "video", etc
modality_id = modality.name.lower()
embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
-if len(items) != 0 and embedder is not None:
+if len(items) != 0:
+assert embedder is not None, f"no embedding method found for {modality}"
placeholder_tensor = torch.as_tensor(
[item.pad_value for item in items],
device=input_ids.device,
@@ -580,11 +582,12 @@ def embed_mm_inputs(
items_offset_list=items_offsets,
)
-if use_deepstack and embedding is not None:
+if use_deepstack.get(modality, None) and embedding is not None:
embedding, deepstack_embedding = (
multimodal_model.separate_deepstack_embeds(embedding)
)
deepstack_embeddings += [deepstack_embedding]
+modalities += [modality]
embeddings += [embedding]
masks += [mask]
@@ -597,17 +600,14 @@ def embed_mm_inputs(
input_ids.clamp_(min=0, max=vocab_size - 1)
inputs_embeds = input_embedding(input_ids)
-# 4. scatter embeddings into input embedding
# deepstack embedding
if use_deepstack:
-num_deepstack_embeddings = (
-len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
-)
+num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes)
deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
inputs_embeds.shape[-1] * num_deepstack_embeddings,
)
-# a zero-filled embedding, with the same length of inputs_embeds, but different hidden_size
input_deepstack_embeds = torch.zeros(
deepstack_embedding_shape,
device=inputs_embeds.device,
@@ -616,14 +616,16 @@ def embed_mm_inputs(
other_info["input_deepstack_embeds"] = input_deepstack_embeds
-for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
+# 4. scatter embeddings into input embedding
+for i, modality, embedding, mask in zip(
+range(len(embeddings)), modalities, embeddings, masks
+):
if embedding is None or mask is None:
continue
# in-place update
indices = torch.where(mask.squeeze(dim=-1))[0]
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
-if use_deepstack:
+if use_deepstack.get(modality, None):
input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
inputs_embeds.device, inputs_embeds.dtype
)
@@ -640,7 +642,7 @@ def general_mm_embed_routine(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
-use_deepstack: bool = False,
+use_deepstack: Dict[Modality, bool] = {},
**kwargs,
) -> torch.Tensor:
"""
@@ -652,7 +654,7 @@ def general_mm_embed_routine(
language_model: Base language model to use
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
placeholder_tokens: Token IDs for multimodal placeholders
-use_deepstack: Whether to use deepstack embeddings
+use_deepstack: Whether to use deepstack embeddings for each modality, default False
**kwargs: Additional arguments passed to language model
Returns:
...
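With use_deepstack now a per-modality dict, embed_mm_inputs allocates a single zero-filled buffer of width hidden_size x num_deepstack_embeddings and scatters each modality's extra features into it at the placeholder positions; decoder layers later add back one hidden_size-wide slice each. A standalone torch sketch of that shape handling with toy sizes (illustrative, not the SGLang code path):

import torch

hidden_size, num_deepstack, num_tokens = 8, 3, 10
inputs_embeds = torch.zeros(num_tokens, hidden_size)
# one zero-filled buffer holds all deepstack levels side by side
input_deepstack_embeds = torch.zeros(num_tokens, hidden_size * num_deepstack)

# pretend tokens 1, 2 and 6 are image placeholders
mask = torch.tensor([0, 1, 1, 0, 0, 0, 1, 0, 0, 0], dtype=torch.bool)
indices = torch.where(mask)[0]

# stand-ins for the outputs of separate_deepstack_embeds()
image_embeds = torch.randn(len(indices), hidden_size)
image_deepstack = torch.randn(len(indices), hidden_size * num_deepstack)

inputs_embeds[indices] = image_embeds
input_deepstack_embeds[indices] = image_deepstack

# later, decoder layer k (k < num_deepstack) adds back its own slice in place
k = 1
sep = hidden_size * k
inputs_embeds += input_deepstack_embeds[:, sep : sep + hidden_size]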
@@ -587,9 +587,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
)
if self.mm_processor and obj.contains_mm_input():
-if not isinstance(obj.image_data, list) and obj.image_data:
+if obj.image_data is not None and not isinstance(obj.image_data, list):
obj.image_data = [obj.image_data]
-if not isinstance(obj.audio_data, list) and obj.audio_data:
+if obj.audio_data is not None and not isinstance(obj.audio_data, list):
obj.audio_data = [obj.audio_data]
mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
image_data=obj.image_data,
...
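The tokenizer-manager change replaces a truthiness test with an explicit is-not-None check, so any present, non-list value is wrapped into a single-element list. A minimal illustration of the new rule (hypothetical helper name):

from typing import Any, List, Optional

def normalize_mm_field(value: Optional[Any]) -> Optional[List[Any]]:
    # Mirrors the updated TokenizerManager logic: wrap any present,
    # non-list value in a list; leave None and lists untouched.
    if value is not None and not isinstance(value, list):
        return [value]
    return value

assert normalize_mm_field(None) is None
assert normalize_mm_field("speech.wav") == ["speech.wav"]
assert normalize_mm_field(["a.png", "b.png"]) == ["a.png", "b.png"]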
@@ -518,6 +518,7 @@ class Qwen2MoeModel(nn.Module):
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()
...
@@ -661,13 +661,14 @@ class Qwen3MoeModel(Qwen2MoeModel):
config: Qwen3MoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
+decoder_layer_type=Qwen3MoeDecoderLayer,
) -> None:
alt_stream = torch.cuda.Stream() if _is_cuda else None
super().__init__(
config=config,
quant_config=quant_config,
prefix=prefix,
-decoder_layer_type=Qwen3MoeDecoderLayer,
+decoder_layer_type=decoder_layer_type,
alt_stream=alt_stream,
)
...
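Exposing decoder_layer_type as a constructor argument lets subclasses reuse Qwen3MoeModel.__init__ while swapping in their own decoder layer class, which is how the Qwen3-Omni thinker reuses the MoE stack. A schematic, framework-free sketch of the pattern (class names here are illustrative):

class DecoderLayer:
    def __init__(self, layer_id: int):
        self.layer_id = layer_id

class ThinkerDecoderLayer(DecoderLayer):
    pass

class BaseModel:
    def __init__(self, num_layers: int, decoder_layer_type=DecoderLayer):
        # the layer class is injected instead of being hard-coded
        self.layers = [decoder_layer_type(i) for i in range(num_layers)]

class ThinkerModel(BaseModel):
    def __init__(self, num_layers: int):
        super().__init__(num_layers, decoder_layer_type=ThinkerDecoderLayer)

assert all(isinstance(l, ThinkerDecoderLayer) for l in ThinkerModel(2).layers)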
This diff is collapsed.
@@ -15,7 +15,7 @@
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
-from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
+from typing import Callable, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -27,7 +27,11 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionRotaryEmbedding,
)
-from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
+from sglang.srt.configs.qwen3_vl import (
+Qwen3VLConfig,
+Qwen3VLTextConfig,
+Qwen3VLVisionConfig,
+)
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -38,16 +42,24 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.managers.schedule_batch import (
+Modality,
+MultimodalDataItem,
+MultimodalInputs,
+)
+from sglang.srt.model_executor.forward_batch_info import (
+ForwardBatch,
+ForwardMode,
+PPProxyTensors,
+)
from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
from sglang.srt.models.qwen3 import Qwen3Model
from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor
logger = logging.getLogger(__name__)
# === Vision Encoder === #
@@ -196,7 +208,7 @@ class Qwen3_VisionBlock(nn.Module):
return x
-class Qwen3_VisionPatchMerger(nn.Module):
+class Qwen3VLMoeVisionPatchMerger(nn.Module):
def __init__(
self,
@@ -246,7 +258,7 @@ class Qwen3_VisionPatchMerger(nn.Module):
return out
-class Qwen3_VisionTransformer(nn.Module):
+class Qwen3VLMoeVisionModel(nn.Module):
def __init__(
self,
@@ -263,10 +275,10 @@ class Qwen3_VisionTransformer(nn.Module):
self.spatial_merge_size = vision_config.spatial_merge_size
self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size
+# layer indexes of which layer's output should be deep-stacked
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
@@ -286,7 +298,7 @@ class Qwen3_VisionTransformer(nn.Module):
for layer_idx in range(vision_config.depth)
]
)
-self.merger = Qwen3_VisionPatchMerger(
+self.merger = Qwen3VLMoeVisionPatchMerger(
dim=vision_config.out_hidden_size,
context_dim=self.hidden_size,
norm_layer=norm_layer,
@@ -297,7 +309,7 @@ class Qwen3_VisionTransformer(nn.Module):
self.deepstack_merger_list = nn.ModuleList(
[
-Qwen3_VisionPatchMerger(
+Qwen3VLMoeVisionPatchMerger(
dim=vision_config.out_hidden_size,
context_dim=self.hidden_size,
spatial_merge_size=self.spatial_merge_size,
@@ -462,7 +474,6 @@ class Qwen3_VisionTransformer(nn.Module):
]
)
-# max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
x = x.unsqueeze(1)
deepstack_feature_lists = []
@@ -604,37 +615,43 @@ class Qwen3VLForConditionalGeneration(nn.Module):
config: Qwen3VLConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
+language_model_cls=Qwen3LLMModel,
) -> None:
super().__init__()
-self.config = config
-self.visual = Qwen3_VisionTransformer(
+self.visual = Qwen3VLMoeVisionModel(
config.vision_config,
-norm_eps=getattr(config, "rms_norm_eps", 1e-6),
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
quant_config=quant_config,
+norm_eps=getattr(config, "rms_norm_eps", 1e-6),
prefix=add_prefix("visual", prefix),
)
-self.model = Qwen3LLMModel(
-config=config,
+# TODO: make it more elegant
+if language_model_cls is Qwen3LLMModel:
+self.config: Qwen3VLConfig = config  # for qwen3-vl
+else:
+self.config = config.text_config  # for qwen3-omni
+self.model = language_model_cls(
+config=self.config,
quant_config=quant_config,
prefix=add_prefix("model", prefix),
)
-if config.tie_word_embeddings:
+if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
-config.vocab_size,
+self.config.vocab_size,
-config.hidden_size,
+self.config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
-self.logits_processor = LogitsProcessor(config)
+self.logits_processor = LogitsProcessor(self.config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on
# 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states
@@ -642,10 +659,7 @@ class Qwen3VLForConditionalGeneration(nn.Module):
# deepstack
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
+self.use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True}
-@property
-def use_deepstack(self) -> bool:
-return hasattr(self, "deepstack_visual_indexes")
def separate_deepstack_embeds(self, embedding):
assert (
...
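The constructor now branches on language_model_cls: the plain Qwen3-VL path keeps the top-level config, while the Omni path, which injects a different language-model class, reads the thinker's decoder settings from config.text_config. A stripped-down sketch of that selection logic under assumed, simplified config classes:

from dataclasses import dataclass, field

@dataclass
class TextConfig:
    vocab_size: int = 152064
    hidden_size: int = 2048

@dataclass
class OmniThinkerConfig:
    # Omni-style config: the language-model settings live in a nested text_config
    text_config: TextConfig = field(default_factory=TextConfig)

class DefaultVLLanguageModel: ...
class OmniLanguageModel: ...

def pick_text_config(config, language_model_cls):
    # Mirrors the branch above: the default VL path keeps `config`,
    # any injected language model uses `config.text_config`.
    return config if language_model_cls is DefaultVLLanguageModel else config.text_config

cfg = OmniThinkerConfig()
assert pick_text_config(cfg, OmniLanguageModel) is cfg.text_config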
@@ -14,29 +14,19 @@
# ==============================================================================
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import logging
-from functools import lru_cache, partial
+from functools import lru_cache
-from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
+from typing import Iterable, Optional, Tuple, Union
-import numpy as np
import torch
import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-from transformers import BatchFeature
-from transformers.activations import ACT2FN
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-Qwen2_5_VisionRotaryEmbedding,
-)
-from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeVisionConfig
+from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
-get_pp_group,
get_tensor_model_parallel_rank,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import general_mm_embed_routine
@@ -44,11 +34,7 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
-from sglang.srt.models.qwen3_vl import (
-Qwen3_VisionTransformer,
-Qwen3VLForConditionalGeneration,
-)
+from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
-from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor
logger = logging.getLogger(__name__)
@@ -60,28 +46,16 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
def __init__(
self,
*,
-config: Qwen3VLMoeConfig,
+config: Qwen3VLMoeTextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(config=config, quant_config=quant_config, prefix=prefix)
self.hidden_size = config.hidden_size
def get_input_embeddings(self) -> nn.Embedding:
return self.embed_tokens
-def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-# in qwen-vl, last dim is the same
-pixel_values = torch.cat([item.feature for item in items], dim=0).type(
-self.visual.dtype
-)
-image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
-assert pixel_values.dim() == 2, pixel_values.dim()
-assert image_grid_thw.dim() == 2, image_grid_thw.dim()
-image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-return image_embeds
def forward(
self,
input_ids: torch.Tensor,
@@ -120,7 +94,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
)
# process deepstack
-if input_deepstack_embeds is not None and layer_idx in range(3):
+if input_deepstack_embeds is not None and layer_idx < 3:
sep = self.hidden_size * layer_idx
hidden_states.add_(
input_deepstack_embeds[:, sep : sep + self.hidden_size]
@@ -146,144 +120,56 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
return hidden_states, aux_hidden_states
-class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
-def __init__(
-self,
-*,
-config: Qwen3VLMoeConfig,
-quant_config: Optional[QuantizationConfig] = None,
-prefix: str = "",
-):
-super(Qwen3VLForConditionalGeneration, self).__init__()
-self.config = config
-self.visual = Qwen3_VisionTransformer(
-config.vision_config,
-norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
-# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
-quant_config=quant_config,
-prefix=add_prefix("visual", prefix),
-)
-self.model = Qwen3MoeLLMModel(
-config=config,
-quant_config=quant_config,
-prefix=add_prefix("model", prefix),
-)
-if config.tie_word_embeddings:
-self.lm_head = self.model.embed_tokens
-else:
-self.lm_head = ParallelLMHead(
-config.vocab_size,
-config.hidden_size,
-quant_config=quant_config,
-prefix=add_prefix("lm_head", prefix),
-)
-self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
-self.logits_processor = LogitsProcessor(config)
-self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-# deepstack
-self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
-self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
-@property
-def use_deepstack(self) -> bool:
-return hasattr(self, "deepstack_visual_indexes")
-def forward(
-self,
-input_ids: torch.Tensor,
-positions: torch.Tensor,
-forward_batch: ForwardBatch,
-get_embedding: bool = False,
-):
-"""Run forward pass for Qwen3-VL.
-Args:
-input_ids: Flattened (concatenated) input_ids corresponding to a
-batch.
-positions: Flattened (concatenated) position ids corresponding to a
-batch.
-**NOTE**: If mrope is enabled (default setting for Qwen2-VL
-opensource models), the shape will be `(3, seq_len)`,
-otherwise it will be `(seq_len,).
-(Use input_metadata.mrope_positions to replace it)
-"""
-if self.is_mrope_enabled:
-positions = forward_batch.mrope_positions
-if not (
-forward_batch.forward_mode.is_decode()
-or not forward_batch.contains_image_inputs()
-):
-if self.is_mrope_enabled:
-assert positions.ndim == 2 and positions.size(0) == 3, (
-"multimodal section rotary embedding requires "
-f"(3, seq_len) positions, but got {positions.size()}"
-)
-hidden_states = general_mm_embed_routine(
-input_ids=input_ids,
-forward_batch=forward_batch,
-language_model=self.model,
-multimodal_model=self,
-positions=positions,
-use_deepstack=self.use_deepstack,
-)
-if not get_embedding:
-return self.logits_processor(
-input_ids, hidden_states, self.lm_head, forward_batch
-)
-else:
-return self.pooler(hidden_states, forward_batch)
-def load_fused_expert_weights(
-self,
-name: str,
-params_dict: dict,
-loaded_weight: torch.Tensor,
-shard_id: str,
-num_experts: int,
-):
-param = params_dict[name]
-# weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
-weight_loader = param.weight_loader
-ep_rank = get_tensor_model_parallel_rank()
-ep_size = get_moe_expert_parallel_world_size()
-if ep_size == 1:
-for expert_id in range(num_experts):
-curr_expert_weight = loaded_weight[expert_id]
-weight_loader(
-param,
-curr_expert_weight,
-name,
-shard_id,
-expert_id,
-)
-else:
-experts_per_ep = num_experts // ep_size
-start_expert = ep_rank * experts_per_ep
-end_expert = (
-(ep_rank + 1) * experts_per_ep
-if ep_rank != ep_size - 1
-else num_experts
-)
-for idx, expert_id in enumerate(range(start_expert, end_expert)):
-curr_expert_weight = loaded_weight[expert_id]
-weight_loader(
-param,
-curr_expert_weight,
-name,
-shard_id,
-idx,
-)
-return True
+def load_fused_expert_weights(
+name: str,
+params_dict: dict,
+loaded_weight: torch.Tensor,
+shard_id: str,
+num_experts: int,
+):
+param = params_dict[name]
+# weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
+weight_loader = param.weight_loader
+ep_rank = get_tensor_model_parallel_rank()
+ep_size = get_moe_expert_parallel_world_size()
+if ep_size == 1:
+for expert_id in range(num_experts):
+curr_expert_weight = loaded_weight[expert_id]
+weight_loader(
+param,
+curr_expert_weight,
+name,
+shard_id,
+expert_id,
+)
+else:
+experts_per_ep = num_experts // ep_size
+start_expert = ep_rank * experts_per_ep
+end_expert = (
+(ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
+)
+for idx, expert_id in enumerate(range(start_expert, end_expert)):
+curr_expert_weight = loaded_weight[expert_id]
+weight_loader(
+param,
+curr_expert_weight,
+name,
+shard_id,
+idx,
+)
+return True
+class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
+def __init__(
+self,
+config: Qwen3VLMoeConfig,
+quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
+language_model_cls=Qwen3MoeLLMModel,
+):
+super().__init__(config, quant_config, prefix, language_model_cls)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
@@ -329,8 +215,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
self._cached_params_dict = dict(self.named_parameters())
params_dict = self._cached_params_dict
for name, loaded_weight in weights:
-if "language_model" in name:
-name = name.replace(r"model.language_model.", r"model.")
+name = name.replace(r"model.language_model.", r"model.")
for param_name, weight_name, shard_id in stacked_params_mapping:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
@@ -384,14 +269,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
loaded_weight = loaded_weight.transpose(-1, -2) # no bias
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-2)
-self.load_fused_expert_weights(
+load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[0],
"w1",
num_experts,
)
-self.load_fused_expert_weights(
+load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[1],
@@ -399,7 +284,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
num_experts,
)
else:
-self.load_fused_expert_weights(
+load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight,
...
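load_fused_expert_weights, now a module-level helper, slices the fused expert tensor by expert-parallel rank: every rank loads num_experts // ep_size contiguous experts and the last rank also takes the remainder. A small sketch of just that partitioning arithmetic:

def expert_range_for_rank(ep_rank: int, ep_size: int, num_experts: int):
    # Same partitioning as load_fused_expert_weights: contiguous blocks,
    # with the last rank absorbing any remainder.
    experts_per_ep = num_experts // ep_size
    start = ep_rank * experts_per_ep
    end = (ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
    return start, end

# 10 experts over 4 ranks: ranks 0-2 load 2 experts each, rank 3 loads the last 4
assert [expert_range_for_rank(r, 4, 10) for r in range(4)] == [(0, 2), (2, 4), (4, 6), (6, 10)]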
@@ -155,7 +155,6 @@ class BaseMultimodalProcessor(ABC):
):
self.hf_config = hf_config
self._processor = _processor
-self.arch = hf_config.architectures[0]
self.server_args = server_args
self.transport_mode = transport_mode
@@ -191,6 +190,7 @@ class BaseMultimodalProcessor(ABC):
"input_features": Modality.AUDIO,
"input_features_mask": Modality.AUDIO,
"audio_attention_mask": Modality.AUDIO,
+"feature_attention_mask": Modality.AUDIO,
# Video-related attributes
"pixel_values_videos": Modality.VIDEO,
"second_per_grid_ts": Modality.VIDEO,
@@ -222,6 +222,7 @@ class BaseMultimodalProcessor(ABC):
if self._processor.__class__.__name__ in {
"Gemma3nProcessor",
"Qwen2AudioProcessor",
+"Qwen3OmniMoeProcessor",
}:
# Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
kwargs["audio"] = audios
...
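Mapping feature_attention_mask to Modality.AUDIO is what lets the generic processor route that tensor, produced by the Qwen3-Omni feature extractor, to audio items. A toy version of this key-to-modality bucketing (hypothetical names, not the SGLang classes):

from enum import Enum, auto

class Modality(Enum):
    IMAGE = auto()
    VIDEO = auto()
    AUDIO = auto()

ATTR_TO_MODALITY = {
    "pixel_values": Modality.IMAGE,
    "pixel_values_videos": Modality.VIDEO,
    "input_features": Modality.AUDIO,
    "feature_attention_mask": Modality.AUDIO,  # new in this commit
}

def bucket_processor_output(output: dict) -> dict:
    # Group HF-processor tensors by the modality they belong to.
    buckets = {modality: {} for modality in Modality}
    for key, value in output.items():
        modality = ATTR_TO_MODALITY.get(key)
        if modality is not None:
            buckets[modality][key] = value
    return buckets

out = bucket_processor_output({"input_features": "feats", "feature_attention_mask": "mask"})
assert set(out[Modality.AUDIO]) == {"input_features", "feature_attention_mask"}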
@@ -12,6 +12,7 @@ from torchvision.transforms import InterpolationMode
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
+from sglang.srt.models.qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration
from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (
@@ -209,22 +210,31 @@ async def preprocess_video(
return video
-# Compatible with Qwen2VL and Qwen2_5VL
-class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
+# Compatible with Qwen-VL & Qwen-Omni Series
+class QwenVLImageProcessor(SGLangBaseProcessor):
models = [
Qwen2VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
Qwen3VLForConditionalGeneration,
Qwen3VLMoeForConditionalGeneration,
+Qwen3OmniMoeForConditionalGeneration,
]
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+self.model_type = hf_config.model_type
+if hf_config.model_type == "qwen3_omni_moe":
+hf_config = hf_config.thinker_config
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
-# The regex that matches expanded image tokens.
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
self.vision_start_token_id = hf_config.vision_start_token_id
self.vision_end_token_id = hf_config.vision_end_token_id
+self.audio_start_token_id = getattr(hf_config, "audio_start_token_id", None)
+self.audio_token_id = getattr(hf_config, "audio_token_id", None)
self.NUM_TOKEN_PER_FRAME = 770
self.IMAGE_FACTOR = 28
self.MIN_PIXELS = 4 * 28 * 28
@@ -233,10 +243,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self.mm_tokens = MultimodalSpecialTokens(
image_token="<|vision_start|><|image_pad|><|vision_end|>",
image_token_id=hf_config.image_token_id,
+# The regex that matches expanded image tokens.
image_token_regex=re.compile(
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
),
video_token_id=hf_config.video_token_id,
+audio_token_id=self.audio_token_id,
).build(_processor)
async def process_mm_data_async(
@@ -247,11 +259,11 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
*args,
**kwargs,
):
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
video_data=request_obj.video_data,
+audio_data=request_obj.audio_data,
multimodal_tokens=self.mm_tokens,
)
@@ -269,20 +281,41 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
base_output, self.mm_tokens
)
+audio_feature_lengths = None
+if self.model_type == "qwen3_omni_moe":
+audio_item = next((mm for mm in mm_items if mm.is_audio()), None)
+if audio_item:
+audio_feature_lengths = torch.sum(
+audio_item.feature_attention_mask, dim=1
+)
+second_per_grid_ts = getattr(ret, "second_per_grid_ts", None) or getattr(
+ret, "video_second_per_grid", None
+)
input_ids = input_ids.flatten()
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
image_token_id=self.mm_tokens.image_token_id,
video_token_id=self.mm_tokens.video_token_id,
vision_start_token_id=self.vision_start_token_id,
-model_type=self.hf_config.model_type,
+model_type=self.model_type,
tokens_per_second=getattr(
self.hf_config.vision_config, "tokens_per_second", None
),
input_ids=input_ids.unsqueeze(0),
image_grid_thw=getattr(ret, "image_grid_thw", None),
video_grid_thw=getattr(ret, "video_grid_thw", None),
-second_per_grid_ts=getattr(ret, "second_per_grid_ts", None),
+second_per_grid_ts=second_per_grid_ts,
+use_audio_in_video=False,
+audio_seqlens=audio_feature_lengths,
+audio_token_id=getattr(self.hf_config, "audio_token_id", None),
+audio_start_token_id=self.audio_start_token_id,
+position_id_per_seconds=getattr(
+self.hf_config, "position_id_per_seconds", None
+),
)
mrope_positions = mrope_positions.squeeze(1)
@@ -293,6 +326,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
"im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.mm_tokens.image_token_id,
"video_token_id": self.mm_tokens.video_token_id,
+"audio_token_id": self.mm_tokens.audio_token_id,
"mrope_positions": mrope_positions,
"mrope_position_delta": mrope_position_delta,
}
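For qwen3_omni_moe, the processor derives per-clip audio lengths by summing the padded feature_attention_mask over its time axis; those lengths feed the audio_seqlens argument of the M-RoPE call above. A minimal torch illustration of that reduction:

import torch

# two audio clips padded to a common length of 6 feature frames
feature_attention_mask = torch.tensor([
    [1, 1, 1, 1, 0, 0],   # clip 0: 4 valid frames
    [1, 1, 1, 1, 1, 1],   # clip 1: 6 valid frames
])

audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
assert audio_feature_lengths.tolist() == [4, 6]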
@@ -355,9 +355,10 @@ class TestPhi4MMServer(ImageOpenAITestMixin, AudioOpenAITestMixin):
if __name__ == "__main__":
del (
-TestOpenAIOmniServerBase,
+TestOpenAIMLLMServerBase,
ImageOpenAITestMixin,
VideoOpenAITestMixin,
AudioOpenAITestMixin,
+OmniOpenAITestMixin,
)
unittest.main()
@@ -241,11 +241,35 @@ class TestGLM41VServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
cls.base_url += "/v1"
class TestQwen3OmniServer(OmniOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[ # workaround to fit into H100
"--trust-remote-code",
"--mem-fraction-static",
"0.90",
"--disable-cuda-graph",
"--disable-fast-image-processor",
"--grammar-backend",
"none",
],
)
cls.base_url += "/v1"
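Once this server is up, the mixins talk to it through the standard OpenAI-compatible endpoint. A hedged, minimal client call against such a server, with the API key and model name mirroring the test constants (the host/port value of DEFAULT_URL_FOR_TEST is an assumption here):

import openai

# Endpoint mirrors the test constants; the port is an assumption, since
# DEFAULT_URL_FOR_TEST is defined elsewhere in the test utilities.
client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Reply with a one-sentence greeting."}],
    temperature=0,
)
print(response.choices[0].message.content)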
if __name__ == "__main__": if __name__ == "__main__":
del ( del (
TestOpenAIOmniServerBase, TestOpenAIMLLMServerBase,
ImageOpenAITestMixin, ImageOpenAITestMixin,
VideoOpenAITestMixin, VideoOpenAITestMixin,
AudioOpenAITestMixin, AudioOpenAITestMixin,
OmniOpenAITestMixin,
) )
unittest.main() unittest.main()
import base64
import io
import os
+from concurrent.futures import ThreadPoolExecutor
import numpy as np
import openai
@@ -22,7 +23,7 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
-class TestOpenAIOmniServerBase(CustomTestCase):
+class TestOpenAIMLLMServerBase(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = ""
@@ -58,7 +59,20 @@ class TestOpenAIOmniServerBase(CustomTestCase):
return file_path
-class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
+class AudioOpenAITestMixin(TestOpenAIMLLMServerBase):
def verify_speech_recognition_response(self, text):
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in text.lower()
), f"audio_response: |{text}| should contain |{check_word}|"
def prepare_audio_messages(self, prompt, audio_file_name):
messages = [
{
@@ -116,17 +130,7 @@ class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
"Listen to this audio and write down the audio transcription in English.",
category="speech",
)
-check_list = [
-"thank you",
-"it's a privilege to be here",
-"leader",
-"science",
-"art",
-]
-for check_word in check_list:
-assert (
-check_word in audio_response
-), f"audio_response: |{audio_response}| should contain |{check_word}|"
+self.verify_speech_recognition_response(audio_response)
def test_audio_ambient_completion(self):
# bird song
@@ -138,26 +142,39 @@ class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
assert "bird" in audio_response
-class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
+class ImageOpenAITestMixin(TestOpenAIMLLMServerBase):
-def test_single_image_chat_completion(self):
+def run_decode_with_image(self, image_id):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
content = []
if image_id == 0:
content.append(
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
}
)
elif image_id == 1:
content.append(
{
"type": "image_url",
"image_url": {"url": IMAGE_SGL_LOGO_URL},
}
)
else:
pass
content.append(
{
"type": "text",
"text": "Describe this image in a sentence.",
}
)
response = client.chat.completions.create(
model="default",
messages=[
-{
-"role": "user",
-"content": [
-{
-"type": "image_url",
-"image_url": {"url": IMAGE_MAN_IRONING_URL},
-},
-{
-"type": "text",
-"text": "Describe this image in a sentence.",
-},
-],
-},
+{"role": "user", "content": content},
],
temperature=0,
**(self.get_vision_request_kwargs()),
@@ -166,6 +183,17 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
def test_mixed_batch(self):
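# Twelve concurrent requests mixing two different images and text-only prompts, forcing the server to batch heterogeneous multimodal inputs.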
image_ids = [0, 1, 2] * 4
with ThreadPoolExecutor(4) as executor:
list(executor.map(self.run_decode_with_image, image_ids))
def verify_single_image_response(self, response):
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
# `driver` is for gemma-3-it # `driver` is for gemma-3-it
assert ( assert (
"man" in text or "person" or "driver" in text "man" in text or "person" or "driver" in text
...@@ -179,19 +207,44 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase): ...@@ -179,19 +207,44 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
), f"text: {text}, should contain cab, taxi, SUV, vehicle or car" ), f"text: {text}, should contain cab, taxi, SUV, vehicle or car"
# MiniCPMO fails to recognize `iron`, but `hanging` # MiniCPMO fails to recognize `iron`, but `hanging`
assert ( assert (
"iron" in text "iron" in text or "hang" in text or "cloth" in text or "holding" in text
or "hang" in text ), f"text: {text}, should contain iron, hang, cloth or holding"
or "cloth" in text
or "coat" in text
or "holding" in text
or "outfit" in text
), f"text: {text}, should contain iron, hang, cloth, coat or holding or outfit"
assert response.id assert response.id
assert response.created assert response.created
assert response.usage.prompt_tokens > 0 assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0 assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0 assert response.usage.total_tokens > 0
def test_single_image_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "text",
"text": "Describe this image in a sentence.",
},
],
},
],
temperature=0,
**(self.get_vision_request_kwargs()),
)
print("-" * 30)
print(f"Single image response:\n{response.choices[0].message.content}")
print("-" * 30)
self.verify_single_image_response(response)
def test_multi_turn_chat_completion(self): def test_multi_turn_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url) client = openai.Client(api_key=self.api_key, base_url=self.base_url)
...@@ -264,8 +317,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase): ...@@ -264,8 +317,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
}, },
{ {
"type": "text", "type": "text",
"text": "I have two very different images. They are not related at all. " "text": "I have two very different images. Please describe them.",
"Please describe the first image in one sentence, and then describe the second image in another sentence.",
}, },
], ],
}, },
...@@ -296,64 +348,6 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase): ...@@ -296,64 +348,6 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
assert response.usage.completion_tokens > 0 assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0 assert response.usage.total_tokens > 0
def _test_mixed_image_audio_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "audio_url",
"audio_url": {"url": AUDIO_TRUMP_SPEECH_URL},
},
{
"type": "text",
"text": "Please describe the image in one sentence, and then write down the audio transcription in English.",
},
],
},
],
temperature=0,
**(self.get_vision_request_kwargs()),
)
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
print("-" * 30)
print(f"Mixed image & audio response:\n{text}")
print("-" * 30)
assert (
"man" in text
or "cab" in text
or "SUV" in text
or "taxi" in text
or "car" in text
), f"text: {text}, should contain man, cab, SUV, taxi or car"
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in text
), f"text: |{text}| should contain |{check_word}|"
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def prepare_video_images_messages(self, video_path): def prepare_video_images_messages(self, video_path):
# the memory consumed by the Vision Attention varies a lot, e.g. blocked qkv vs full-sequence sdpa # the memory consumed by the Vision Attention varies a lot, e.g. blocked qkv vs full-sequence sdpa
# the size of the video embeds differs depending on the `modality` argument used during preprocessing # the size of the video embeds differs depending on the `modality` argument used during preprocessing
...@@ -461,7 +455,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase): ...@@ -461,7 +455,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
self.assertGreater(len(video_response), 0) self.assertGreater(len(video_response), 0)
class VideoOpenAITestMixin(TestOpenAIOmniServerBase): class VideoOpenAITestMixin(TestOpenAIMLLMServerBase):
def prepare_video_messages(self, video_path): def prepare_video_messages(self, video_path):
messages = [ messages = [
{ {
...@@ -526,3 +520,45 @@ class VideoOpenAITestMixin(TestOpenAIOmniServerBase): ...@@ -526,3 +520,45 @@ class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
), f"video_response: {video_response}, should contain 'black' or 'dark'" ), f"video_response: {video_response}, should contain 'black' or 'dark'"
self.assertIsNotNone(video_response) self.assertIsNotNone(video_response)
self.assertGreater(len(video_response), 0) self.assertGreater(len(video_response), 0)
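The prepare_video_messages body is elided by the hunk above; by analogy with the image_url and audio_url content parts used elsewhere in this file, it presumably builds a payload along the lines of the hedged sketch below. The video_url field names are an assumption, not confirmed by this diff.

# Hypothetical sketch only; field names mirror the image_url/audio_url
# convention used in this module and are not confirmed by this diff.
def prepare_video_messages_sketch(video_path, prompt="Describe this video in a sentence."):
    return [
        {
            "role": "user",
            "content": [
                {"type": "video_url", "video_url": {"url": video_path}},
                {"type": "text", "text": prompt},
            ],
        }
    ]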
class OmniOpenAITestMixin(
ImageOpenAITestMixin, VideoOpenAITestMixin, AudioOpenAITestMixin
):
def test_mixed_modality_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "audio_url",
"audio_url": {"url": AUDIO_TRUMP_SPEECH_URL},
},
{
"type": "text",
"text": "I have an image and audio, which are not related at all. Please: 1. Describe the image in a sentence, 2. Repeat the exact words from the audio I provided. Be exact",
},
],
},
]
response = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=128,
stream=False,
)
text = response.choices[0].message.content
print("-" * 30)
print(f"Mixed modality response:\n{text}")
print("-" * 30)
self.verify_single_image_response(response=response)
self.verify_speech_recognition_response(text=text)
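For context (not introduced by this diff): a concrete suite for the new model would typically just bind OmniOpenAITestMixin to a model path and a launched server. The sketch below is a minimal, hedged example; it assumes the usual sglang test helpers (popen_launch_server, DEFAULT_URL_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, kill_process_tree) keep their current signatures, and the model id is a placeholder rather than the checkpoint this PR targets.

import unittest

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestQwen3OmniServer(OmniOpenAITestMixin):
    @classmethod
    def setUpClass(cls):
        # Placeholder model id; point this at the released Qwen3-Omni thinker checkpoint.
        cls.model = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)


if __name__ == "__main__":
    unittest.main()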