Unverified commit 86b04d25 authored by Mick, committed by GitHub

model: qwen3-omni (thinker-only) (#10911)


Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
parent 85ebeecf
......@@ -853,6 +853,7 @@ multimodal_model_archs = [
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"Qwen3OmniMoeForConditionalGeneration",
"KimiVLForConditionalGeneration",
"InternVLChatModel",
"InternS1ForConditionalGeneration",
......
from typing import Optional, Union
from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
......@@ -576,11 +574,3 @@ class Qwen3VLMoeConfig(PretrainedConfig):
self.vision_start_token_id = vision_start_token_id
self.vision_end_token_id = vision_end_token_id
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
__all__ = [
"Qwen3VLMoeConfig",
"Qwen3VLMoeVisionConfig",
"Qwen3VLConfig",
"Qwen3VLVisionConfig",
]
......@@ -1156,6 +1156,20 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
if model_type == "qwen3_omni_moe":
# For qwen3-omni
return MRotaryEmbedding.get_rope_index_qwen3_omni(
spatial_merge_size,
image_token_id,
video_token_id,
vision_start_token_id,
tokens_per_second,
input_ids,
image_grid_thw,
video_grid_thw,
second_per_grid_ts,
**kwargs,
)
if (
model_type.startswith("qwen3_vl") or model_type.startswith("qwen3_vl_moe")
) and video_grid_thw is not None:
......@@ -1163,6 +1177,7 @@ class MRotaryEmbedding(RotaryEmbedding):
video_grid_thw, video_grid_thw[:, 0], dim=0
)
video_grid_thw[:, 0] = 1
mrope_position_deltas = []
if input_ids is not None and (
image_grid_thw is not None or video_grid_thw is not None
......@@ -1248,7 +1263,11 @@ class MRotaryEmbedding(RotaryEmbedding):
time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten()
elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
elif model_type in (
"qwen2_vl",
"qwen3_vl",
"qwen3_vl_moe",
):
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
......@@ -1256,7 +1275,7 @@ class MRotaryEmbedding(RotaryEmbedding):
.flatten()
)
else:
raise RuntimeError("Unimplemented")
raise RuntimeError(f"Unimplemented model type: {model_type}")
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
......@@ -1306,6 +1325,304 @@ class MRotaryEmbedding(RotaryEmbedding):
mrope_position_deltas = max_position_ids + 1 - s
return position_ids, mrope_position_deltas
@staticmethod
def get_rope_index_qwen3_omni(
spatial_merge_size: int,
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
tokens_per_second: Optional[int] = None,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# For qwen3-omni
audio_token_id = kwargs["audio_token_id"]
audio_start_token_id = kwargs["audio_start_token_id"]
position_id_per_seconds = kwargs["position_id_per_seconds"]
use_audio_in_video = kwargs.get("use_audio_in_video", False)
audio_seqlens = kwargs.get("audio_seqlens", None)
second_per_grids = second_per_grid_ts
mrope_position_deltas = []
if input_ids is not None and (
image_grid_thw is not None or video_grid_thw is not None
):
total_input_ids = input_ids
position_ids = torch.zeros(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=torch.float,
device=input_ids.device,
)
image_idx, video_idx, audio_idx = 0, 0, 0
for i, current_input_ids in enumerate(total_input_ids):
image_nums, video_nums, audio_nums = 0, 0, 0
vision_start_indices = torch.argwhere(
current_input_ids == vision_start_token_id
).squeeze(1)
if vision_start_indices.numel() > 0:
vision_tokens = current_input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (
(vision_tokens == audio_start_token_id).sum()
if use_audio_in_video
else (vision_tokens == video_token_id).sum()
)
audio_nums = torch.sum(current_input_ids == audio_start_token_id)
input_tokens = current_input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos, remain_audios = (
image_nums,
video_nums,
audio_nums,
)
multimodal_nums = (
image_nums + audio_nums
if use_audio_in_video
else image_nums + video_nums + audio_nums
)
for _ in range(multimodal_nums):
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
ed_vision_start = (
input_tokens.index(vision_start_token_id, st)
if (
(
image_token_id in input_tokens
or video_token_id in input_tokens
)
and (remain_videos > 0 or remain_images > 0)
)
else len(input_tokens) + 1
)
ed_audio_start = (
input_tokens.index(audio_start_token_id, st)
if (audio_token_id in input_tokens and remain_audios > 0)
else len(input_tokens) + 1
)
min_ed = min(ed_vision_start, ed_audio_start)
text_len = min_ed - st
if text_len != 0:
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
st_idx += text_len
# Audio in Video
if (
min_ed == ed_vision_start
and ed_vision_start + 1 == ed_audio_start
):
bos_len, eos_len = 2, 2
else:
bos_len, eos_len = 1, 1
llm_pos_ids_list.append(
torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx
)
st_idx += bos_len
# Audio Only
if min_ed == ed_audio_start:
audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
audio_seqlens[audio_idx]
)
llm_pos_ids = (
torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
)
llm_pos_ids_list.append(llm_pos_ids)
st += int(text_len + bos_len + audio_len + eos_len)
audio_idx += 1
remain_audios -= 1
# Image Only
elif (
min_ed == ed_vision_start
and current_input_ids[ed_vision_start + 1] == image_token_id
):
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = (
torch.arange(grid_t) * 1 * position_id_per_seconds
).float()
llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
st_idx,
image_idx,
spatial_merge_size,
t_index,
grid_hs,
grid_ws,
input_ids.device,
)
image_len = image_grid_thw[image_idx].prod() // (
spatial_merge_size**2
)
llm_pos_ids_list.append(llm_pos_ids)
st += int(text_len + bos_len + image_len + eos_len)
image_idx += 1
remain_images -= 1
# Video Only
elif (
min_ed == ed_vision_start
and current_input_ids[ed_vision_start + 1] == video_token_id
):
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* second_per_grids[video_idx].cpu().float()
* position_id_per_seconds
).float()
llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
st_idx,
video_idx,
spatial_merge_size,
t_index,
grid_hs,
grid_ws,
input_ids.device,
)
video_len = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2
)
llm_pos_ids_list.append(llm_pos_ids)
st += int(text_len + bos_len + video_len + eos_len)
video_idx += 1
remain_videos -= 1
# Audio in Video
elif (
min_ed == ed_vision_start
and ed_vision_start + 1 == ed_audio_start
):
audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
audio_seqlens[audio_idx]
)
audio_llm_pos_ids = (
torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
)
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* second_per_grids[video_idx].cpu().float()
* position_id_per_seconds
).float()
video_llm_pos_ids = (
MRotaryEmbedding._get_llm_pos_ids_for_vision(
st_idx,
video_idx,
spatial_merge_size,
t_index,
grid_hs,
grid_ws,
input_ids.device,
)
)
video_data_index, audio_data_index = 0, 0
while (
video_data_index < video_llm_pos_ids.shape[-1]
and audio_data_index < audio_llm_pos_ids.shape[-1]
):
if (
video_llm_pos_ids[0][video_data_index]
<= audio_llm_pos_ids[0][audio_data_index]
):
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_data_index + 1
]
)
video_data_index += 1
else:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_data_index + 1
]
)
audio_data_index += 1
if video_data_index < video_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_llm_pos_ids.shape[-1]
]
)
if audio_data_index < audio_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_llm_pos_ids.shape[-1]
]
)
video_len = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2
)
st += int(text_len + bos_len + audio_len + video_len + eos_len)
audio_idx += 1
video_idx += 1
remain_videos -= 1
remain_audios -= 1
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
llm_pos_ids_list.append(
torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx
)
if st < len(input_tokens):
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
llm_positions = torch.cat(
[item.float() for item in llm_pos_ids_list], dim=1
).reshape(3, -1)
position_ids[..., i, :] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(
llm_positions.max() + 1 - len(current_input_ids)
)
mrope_position_deltas = torch.tensor(
mrope_position_deltas, device=input_ids.device
).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
s = input_ids.shape[1]
position_ids = torch.arange(s)
position_ids = (
position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
-1, keepdim=True
)[0]
mrope_position_deltas = max_position_ids + 1 - s
return position_ids, mrope_position_deltas
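# Editor's note (illustrative, not part of the diff): in the audio-in-video branch above, the
# while-loop merges the video and audio position columns by their temporal (t) row. Assuming
# hypothetical values where the video t-row is [0, 0, 13, 13] and the audio t-row is [0, 1, ..., 25]
# (st_idx ignored), the emitted order is video@0, video@0, audio@0..12, video@13, video@13,
# audio@13..25: a video column is taken whenever its t value is <= the next audio t value, and
# the trailing `if` blocks append whatever remains of either stream.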
# Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
@staticmethod
def get_rope_index_glm4v(
......@@ -1504,6 +1821,44 @@ class MRotaryEmbedding(RotaryEmbedding):
return position_ids, mrope_position_deltas
# For qwen3-omni
@staticmethod
def _get_feat_extract_output_lengths(input_lengths):
"""
Computes the output length of the convolutional layers and the output length of the audio encoder
"""
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return output_lengths
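# Editor's note (illustrative, hypothetical input): for input_lengths = 250 mel frames the
# formula above gives input_lengths_leave = 50, feat_lengths = (50 - 1) // 2 + 1 = 25 and
# output_lengths = ((25 - 1) // 2 + 1 - 1) // 2 + 1 + (250 // 100) * 13 = 7 + 26 = 33,
# i.e. each full block of 100 input frames contributes 13 audio positions and the remainder
# is halved three times (50 -> 25 -> 13 -> 7) by the conv downsampling.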
# For qwen3-omni
@staticmethod
def _get_llm_pos_ids_for_vision(
st_idx, vision_idx, spatial_merge_size, t_index, grid_hs, grid_ws, device
):
grid_h = grid_hs[vision_idx] // spatial_merge_size
grid_w = grid_ws[vision_idx] // spatial_merge_size
h_index = (
torch.arange(grid_h, device=device)
.view(1, -1, 1)
.expand(len(t_index), -1, grid_w)
.flatten()
)
w_index = (
torch.arange(grid_w, device=device)
.view(1, 1, -1)
.expand(len(t_index), grid_h, -1)
.flatten()
)
t_index = t_index.view(-1, 1).expand(-1, grid_h * grid_w).flatten()
llm_pos_ids = torch.stack([t_index, h_index, w_index], dim=0) + st_idx
return llm_pos_ids
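# Editor's note (illustrative, hypothetical grid): with spatial_merge_size = 2,
# grid_hs = grid_ws = [4] and t_index = [0., 5.], the merged grid is 2 x 2 per frame and the
# helper returns a (3, 8) tensor (before st_idx is added):
#   t: [0, 0, 0, 0, 5, 5, 5, 5]
#   h: [0, 0, 1, 1, 0, 0, 1, 1]
#   w: [0, 1, 0, 1, 0, 1, 0, 1]
# so each vision token carries a (t, h, w) triple that MRoPE rotates per axis.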
class DualChunkRotaryEmbedding(CustomOp):
"""Rotary positional embedding for Dual Chunk Attention."""
......
......@@ -280,7 +280,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
input_ids_tensor[input_ids_tensor == token_id] = pad_value
ret_input_ids = input_ids_tensor.tolist()
return ret_input_ids
......@@ -507,7 +506,7 @@ def embed_mm_inputs(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
use_deepstack: bool = False,
use_deepstack: Dict[Modality, bool] = {},
) -> Optional[torch.Tensor]:
"""
Embed multimodal inputs and integrate them with text token embeddings.
......@@ -533,7 +532,9 @@ def embed_mm_inputs(
for mm_inputs in mm_inputs_list:
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
embeddings, masks, deepstack_embeddings = [], [], []
# deepstack_embeddings: per-modality
modalities, embeddings, masks, deepstack_embeddings = [], [], [], []
# 2. Get multimodal embedding separately
# Try get mm embedding if any
for modality in Modality.all():
......@@ -549,7 +550,8 @@ def embed_mm_inputs(
# "image", "video", etc
modality_id = modality.name.lower()
embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
if len(items) != 0 and embedder is not None:
if len(items) != 0:
assert embedder is not None, f"no embedding method found for {modality}"
placeholder_tensor = torch.as_tensor(
[item.pad_value for item in items],
device=input_ids.device,
......@@ -580,11 +582,12 @@ def embed_mm_inputs(
items_offset_list=items_offsets,
)
if use_deepstack and embedding is not None:
if use_deepstack.get(modality, None) and embedding is not None:
embedding, deepstack_embedding = (
multimodal_model.separate_deepstack_embeds(embedding)
)
deepstack_embeddings += [deepstack_embedding]
modalities += [modality]
embeddings += [embedding]
masks += [mask]
......@@ -597,17 +600,14 @@ def embed_mm_inputs(
input_ids.clamp_(min=0, max=vocab_size - 1)
inputs_embeds = input_embedding(input_ids)
# 4. scatter embeddings into input embedding
# deepstack embedding
if use_deepstack:
num_deepstack_embeddings = (
len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
)
num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes)
deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
inputs_embeds.shape[-1] * num_deepstack_embeddings,
)
# a zero-filled embedding with the same sequence length as inputs_embeds but a different hidden_size
input_deepstack_embeds = torch.zeros(
deepstack_embedding_shape,
device=inputs_embeds.device,
......@@ -616,14 +616,16 @@ def embed_mm_inputs(
other_info["input_deepstack_embeds"] = input_deepstack_embeds
for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
# 4. scatter embeddings into input embedding
for i, modality, embedding, mask in zip(
range(len(embeddings)), modalities, embeddings, masks
):
if embedding is None or mask is None:
continue
# in-place update
indices = torch.where(mask.squeeze(dim=-1))[0]
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
if use_deepstack:
if use_deepstack.get(modality, None):
input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
inputs_embeds.device, inputs_embeds.dtype
)
......@@ -640,7 +642,7 @@ def general_mm_embed_routine(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
use_deepstack: bool = False,
use_deepstack: Dict[Modality, bool] = {},
**kwargs,
) -> torch.Tensor:
"""
......@@ -652,7 +654,7 @@ def general_mm_embed_routine(
language_model: Base language model to use
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
placeholder_tokens: Token IDs for multimodal placeholders
use_deepstack: Whether to use deepstack embeddings
use_deepstack: Per-modality mapping of whether to use deepstack embeddings; disabled by default
**kwargs: Additional arguments passed to language model
Returns:
......
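A minimal sketch of the intended per-modality usage, inferred from the signature change above and mirroring the call that appears later in the qwen3_vl_moe diff (the surrounding model code is assumed, not taken verbatim from this commit):

# in a multimodal model's __init__: deep-stack image and video features, leave audio plain
self.use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True}

# forward pass: embed_mm_inputs consults use_deepstack.get(modality) for each embedded modality
hidden_states = general_mm_embed_routine(
    input_ids=input_ids,
    forward_batch=forward_batch,
    language_model=self.model,
    multimodal_model=self,
    positions=positions,
    use_deepstack=self.use_deepstack,
)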
......@@ -587,9 +587,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
)
if self.mm_processor and obj.contains_mm_input():
if not isinstance(obj.image_data, list) and obj.image_data:
if obj.image_data is not None and not isinstance(obj.image_data, list):
obj.image_data = [obj.image_data]
if not isinstance(obj.audio_data, list) and obj.audio_data:
if obj.audio_data is not None and not isinstance(obj.audio_data, list):
obj.audio_data = [obj.audio_data]
mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
image_data=obj.image_data,
......
......@@ -518,6 +518,7 @@ class Qwen2MoeModel(nn.Module):
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()
......
......@@ -661,13 +661,14 @@ class Qwen3MoeModel(Qwen2MoeModel):
config: Qwen3MoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
decoder_layer_type=Qwen3MoeDecoderLayer,
) -> None:
alt_stream = torch.cuda.Stream() if _is_cuda else None
super().__init__(
config=config,
quant_config=quant_config,
prefix=prefix,
decoder_layer_type=Qwen3MoeDecoderLayer,
decoder_layer_type=decoder_layer_type,
alt_stream=alt_stream,
)
......
......@@ -15,7 +15,7 @@
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import Callable, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
......@@ -27,7 +27,11 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionRotaryEmbedding,
)
from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
from sglang.srt.configs.qwen3_vl import (
Qwen3VLConfig,
Qwen3VLTextConfig,
Qwen3VLVisionConfig,
)
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
......@@ -38,16 +42,24 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
from sglang.srt.models.qwen3 import Qwen3Model
from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor
logger = logging.getLogger(__name__)
# === Vision Encoder === #
......@@ -196,7 +208,7 @@ class Qwen3_VisionBlock(nn.Module):
return x
class Qwen3_VisionPatchMerger(nn.Module):
class Qwen3VLMoeVisionPatchMerger(nn.Module):
def __init__(
self,
......@@ -246,7 +258,7 @@ class Qwen3_VisionPatchMerger(nn.Module):
return out
class Qwen3_VisionTransformer(nn.Module):
class Qwen3VLMoeVisionModel(nn.Module):
def __init__(
self,
......@@ -263,10 +275,10 @@ class Qwen3_VisionTransformer(nn.Module):
self.spatial_merge_size = vision_config.spatial_merge_size
self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size
# indexes of the vision encoder layers whose outputs are deep-stacked
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
......@@ -286,7 +298,7 @@ class Qwen3_VisionTransformer(nn.Module):
for layer_idx in range(vision_config.depth)
]
)
self.merger = Qwen3_VisionPatchMerger(
self.merger = Qwen3VLMoeVisionPatchMerger(
dim=vision_config.out_hidden_size,
context_dim=self.hidden_size,
norm_layer=norm_layer,
......@@ -297,7 +309,7 @@ class Qwen3_VisionTransformer(nn.Module):
self.deepstack_merger_list = nn.ModuleList(
[
Qwen3_VisionPatchMerger(
Qwen3VLMoeVisionPatchMerger(
dim=vision_config.out_hidden_size,
context_dim=self.hidden_size,
spatial_merge_size=self.spatial_merge_size,
......@@ -462,7 +474,6 @@ class Qwen3_VisionTransformer(nn.Module):
]
)
# max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
x = x.unsqueeze(1)
deepstack_feature_lists = []
......@@ -604,37 +615,43 @@ class Qwen3VLForConditionalGeneration(nn.Module):
config: Qwen3VLConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
language_model_cls=Qwen3LLMModel,
) -> None:
super().__init__()
self.config = config
self.visual = Qwen3_VisionTransformer(
self.visual = Qwen3VLMoeVisionModel(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
quant_config=quant_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
prefix=add_prefix("visual", prefix),
)
self.model = Qwen3LLMModel(
config=config,
# TODO: make it more elegant
if language_model_cls is Qwen3LLMModel:
self.config: Qwen3VLConfig = config # for qwen3-vl
else:
self.config = config.text_config # for qwen3-omni
self.model = language_model_cls(
config=self.config,
quant_config=quant_config,
prefix=add_prefix("model", prefix),
)
if config.tie_word_embeddings:
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
self.config.vocab_size,
self.config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config)
self.logits_processor = LogitsProcessor(self.config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# e.g. {8: 0, 16: 1, 24: 2}: deepstack features captured at vision layers 8, 16 and 24
# are merged into layers 0, 1 and 2 of the decoder output hidden_states
......@@ -642,10 +659,7 @@ class Qwen3VLForConditionalGeneration(nn.Module):
# deepstack
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
@property
def use_deepstack(self) -> bool:
return hasattr(self, "deepstack_visual_indexes")
self.use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True}
def separate_deepstack_embeds(self, embedding):
assert (
......
......@@ -14,29 +14,19 @@
# ==============================================================================
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from functools import lru_cache
from typing import Iterable, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature
from transformers.activations import ACT2FN
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionRotaryEmbedding,
)
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeVisionConfig
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
get_tensor_model_parallel_rank,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import general_mm_embed_routine
......@@ -44,11 +34,7 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_vl import (
Qwen3_VisionTransformer,
Qwen3VLForConditionalGeneration,
)
from sglang.srt.utils import add_prefix
from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
from sglang.srt.utils.hf_transformers_utils import get_processor
logger = logging.getLogger(__name__)
......@@ -60,28 +46,16 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
def __init__(
self,
*,
config: Qwen3VLMoeConfig,
config: Qwen3VLMoeTextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(config=config, quant_config=quant_config, prefix=prefix)
self.hidden_size = config.hidden_size
def get_input_embeddings(self) -> nn.Embedding:
return self.embed_tokens
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
# in qwen-vl, last dim is the same
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
self.visual.dtype
)
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds
def forward(
self,
input_ids: torch.Tensor,
......@@ -120,7 +94,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
)
# process deepstack
if input_deepstack_embeds is not None and layer_idx in range(3):
if input_deepstack_embeds is not None and layer_idx < 3:
sep = self.hidden_size * layer_idx
hidden_states.add_(
input_deepstack_embeds[:, sep : sep + self.hidden_size]
......@@ -146,144 +120,56 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
return hidden_states, aux_hidden_states
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
def __init__(
self,
*,
config: Qwen3VLMoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super(Qwen3VLForConditionalGeneration, self).__init__()
self.config = config
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
quant_config=quant_config,
prefix=add_prefix("visual", prefix),
)
self.model = Qwen3MoeLLMModel(
config=config,
quant_config=quant_config,
prefix=add_prefix("model", prefix),
)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
def load_fused_expert_weights(
name: str,
params_dict: dict,
loaded_weight: torch.Tensor,
shard_id: str,
num_experts: int,
):
param = params_dict[name]
# weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
weight_loader = param.weight_loader
ep_rank = get_tensor_model_parallel_rank()
ep_size = get_moe_expert_parallel_world_size()
if ep_size == 1:
for expert_id in range(num_experts):
curr_expert_weight = loaded_weight[expert_id]
weight_loader(
param,
curr_expert_weight,
name,
shard_id,
expert_id,
)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# deepstack
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
@property
def use_deepstack(self) -> bool:
return hasattr(self, "deepstack_visual_indexes")
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
get_embedding: bool = False,
):
"""Run forward pass for Qwen3-VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,)`.
(Use input_metadata.mrope_positions to replace it)
"""
if self.is_mrope_enabled:
positions = forward_batch.mrope_positions
if not (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
if self.is_mrope_enabled:
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.model,
multimodal_model=self,
positions=positions,
use_deepstack=self.use_deepstack,
else:
experts_per_ep = num_experts // ep_size
start_expert = ep_rank * experts_per_ep
end_expert = (
(ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
for idx, expert_id in enumerate(range(start_expert, end_expert)):
curr_expert_weight = loaded_weight[expert_id]
weight_loader(
param,
curr_expert_weight,
name,
shard_id,
idx,
)
else:
return self.pooler(hidden_states, forward_batch)
return True
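# Editor's note (illustrative, hypothetical sizes): with num_experts = 128 and ep_size = 4,
# experts_per_ep = 32, so rank 0 loads global experts 0-31 as local ids 0-31, rank 1 loads
# 32-63, and the last rank loads 96-127; if num_experts were 130, the last rank would absorb
# the remainder (96-129) because end_expert falls back to num_experts.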
def load_fused_expert_weights(
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
def __init__(
self,
name: str,
params_dict: dict,
loaded_weight: torch.Tensor,
shard_id: str,
num_experts: int,
config: Qwen3VLMoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
language_model_cls=Qwen3MoeLLMModel,
):
param = params_dict[name]
# weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
weight_loader = param.weight_loader
ep_rank = get_tensor_model_parallel_rank()
ep_size = get_moe_expert_parallel_world_size()
if ep_size == 1:
for expert_id in range(num_experts):
curr_expert_weight = loaded_weight[expert_id]
weight_loader(
param,
curr_expert_weight,
name,
shard_id,
expert_id,
)
else:
experts_per_ep = num_experts // ep_size
start_expert = ep_rank * experts_per_ep
end_expert = (
(ep_rank + 1) * experts_per_ep
if ep_rank != ep_size - 1
else num_experts
)
for idx, expert_id in enumerate(range(start_expert, end_expert)):
curr_expert_weight = loaded_weight[expert_id]
weight_loader(
param,
curr_expert_weight,
name,
shard_id,
idx,
)
return True
super().__init__(config, quant_config, prefix, language_model_cls)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......@@ -329,8 +215,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
self._cached_params_dict = dict(self.named_parameters())
params_dict = self._cached_params_dict
for name, loaded_weight in weights:
if "language_model" in name:
name = name.replace(r"model.language_model.", r"model.")
name = name.replace(r"model.language_model.", r"model.")
for param_name, weight_name, shard_id in stacked_params_mapping:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
......@@ -384,14 +269,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
loaded_weight = loaded_weight.transpose(-1, -2) # no bias
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-2)
self.load_fused_expert_weights(
load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[0],
"w1",
num_experts,
)
self.load_fused_expert_weights(
load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[1],
......@@ -399,7 +284,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
num_experts,
)
else:
self.load_fused_expert_weights(
load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight,
......
......@@ -155,7 +155,6 @@ class BaseMultimodalProcessor(ABC):
):
self.hf_config = hf_config
self._processor = _processor
self.arch = hf_config.architectures[0]
self.server_args = server_args
self.transport_mode = transport_mode
......@@ -191,6 +190,7 @@ class BaseMultimodalProcessor(ABC):
"input_features": Modality.AUDIO,
"input_features_mask": Modality.AUDIO,
"audio_attention_mask": Modality.AUDIO,
"feature_attention_mask": Modality.AUDIO,
# Video-related attributes
"pixel_values_videos": Modality.VIDEO,
"second_per_grid_ts": Modality.VIDEO,
......@@ -222,6 +222,7 @@ class BaseMultimodalProcessor(ABC):
if self._processor.__class__.__name__ in {
"Gemma3nProcessor",
"Qwen2AudioProcessor",
"Qwen3OmniMoeProcessor",
}:
# Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
kwargs["audio"] = audios
......
......@@ -12,6 +12,7 @@ from torchvision.transforms import InterpolationMode
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
from sglang.srt.models.qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration
from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (
......@@ -209,22 +210,31 @@ async def preprocess_video(
return video
# Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
# Compatible with Qwen-VL & Qwen-Omni Series
class QwenVLImageProcessor(SGLangBaseProcessor):
models = [
Qwen2VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
Qwen3VLForConditionalGeneration,
Qwen3VLMoeForConditionalGeneration,
Qwen3OmniMoeForConditionalGeneration,
]
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
self.model_type = hf_config.model_type
if hf_config.model_type == "qwen3_omni_moe":
hf_config = hf_config.thinker_config
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
# The regex that matches expanded image tokens.
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
self.vision_start_token_id = hf_config.vision_start_token_id
self.vision_end_token_id = hf_config.vision_end_token_id
self.audio_start_token_id = getattr(hf_config, "audio_start_token_id", None)
self.audio_token_id = getattr(hf_config, "audio_token_id", None)
self.NUM_TOKEN_PER_FRAME = 770
self.IMAGE_FACTOR = 28
self.MIN_PIXELS = 4 * 28 * 28
......@@ -233,10 +243,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self.mm_tokens = MultimodalSpecialTokens(
image_token="<|vision_start|><|image_pad|><|vision_end|>",
image_token_id=hf_config.image_token_id,
# The regex that matches expanded image tokens.
image_token_regex=re.compile(
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
),
video_token_id=hf_config.video_token_id,
audio_token_id=self.audio_token_id,
).build(_processor)
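# Editor's note (illustrative): image_token_regex matches one fully expanded image placeholder,
# i.e. <|vision_start|> followed by one or more <|image_pad|> tokens and <|vision_end|>;
# audio_token_id stays None for the non-omni Qwen-VL models because it is read via getattr above.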
async def process_mm_data_async(
......@@ -247,11 +259,11 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
*args,
**kwargs,
):
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
video_data=request_obj.video_data,
audio_data=request_obj.audio_data,
multimodal_tokens=self.mm_tokens,
)
......@@ -269,20 +281,41 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
base_output, self.mm_tokens
)
audio_feature_lengths = None
if self.model_type == "qwen3_omni_moe":
audio_item = next((mm for mm in mm_items if mm.is_audio()), None)
if audio_item:
audio_feature_lengths = torch.sum(
audio_item.feature_attention_mask, dim=1
)
second_per_grid_ts = getattr(ret, "second_per_grid_ts", None) or getattr(
ret, "video_second_per_grid", None
)
input_ids = input_ids.flatten()
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
image_token_id=self.mm_tokens.image_token_id,
video_token_id=self.mm_tokens.video_token_id,
vision_start_token_id=self.vision_start_token_id,
model_type=self.hf_config.model_type,
model_type=self.model_type,
tokens_per_second=getattr(
self.hf_config.vision_config, "tokens_per_second", None
),
input_ids=input_ids.unsqueeze(0),
image_grid_thw=getattr(ret, "image_grid_thw", None),
video_grid_thw=getattr(ret, "video_grid_thw", None),
second_per_grid_ts=getattr(ret, "second_per_grid_ts", None),
second_per_grid_ts=second_per_grid_ts,
use_audio_in_video=False,
audio_seqlens=audio_feature_lengths,
audio_token_id=getattr(self.hf_config, "audio_token_id", None),
audio_start_token_id=self.audio_start_token_id,
position_id_per_seconds=getattr(
self.hf_config, "position_id_per_seconds", None
),
)
mrope_positions = mrope_positions.squeeze(1)
......@@ -293,6 +326,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
"im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.mm_tokens.image_token_id,
"video_token_id": self.mm_tokens.video_token_id,
"audio_token_id": self.mm_tokens.audio_token_id,
"mrope_positions": mrope_positions,
"mrope_position_delta": mrope_position_delta,
}
......@@ -355,9 +355,10 @@ class TestPhi4MMServer(ImageOpenAITestMixin, AudioOpenAITestMixin):
if __name__ == "__main__":
del (
TestOpenAIOmniServerBase,
TestOpenAIMLLMServerBase,
ImageOpenAITestMixin,
VideoOpenAITestMixin,
AudioOpenAITestMixin,
OmniOpenAITestMixin,
)
unittest.main()
......@@ -241,11 +241,35 @@ class TestGLM41VServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
cls.base_url += "/v1"
class TestQwen3OmniServer(OmniOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[ # workaround to fit into H100
"--trust-remote-code",
"--mem-fraction-static",
"0.90",
"--disable-cuda-graph",
"--disable-fast-image-processor",
"--grammar-backend",
"none",
],
)
cls.base_url += "/v1"
if __name__ == "__main__":
del (
TestOpenAIOmniServerBase,
TestOpenAIMLLMServerBase,
ImageOpenAITestMixin,
VideoOpenAITestMixin,
AudioOpenAITestMixin,
OmniOpenAITestMixin,
)
unittest.main()
import base64
import io
import os
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import openai
......@@ -22,7 +23,7 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
class TestOpenAIOmniServerBase(CustomTestCase):
class TestOpenAIMLLMServerBase(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = ""
......@@ -58,7 +59,20 @@ class TestOpenAIOmniServerBase(CustomTestCase):
return file_path
class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
class AudioOpenAITestMixin(TestOpenAIMLLMServerBase):
def verify_speech_recognition_response(self, text):
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in text.lower()
), f"audio_response: |{text}| should contain |{check_word}|"
def prepare_audio_messages(self, prompt, audio_file_name):
messages = [
{
......@@ -116,17 +130,7 @@ class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
"Listen to this audio and write down the audio transcription in English.",
category="speech",
)
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in audio_response
), f"audio_response: |{audio_response}| should contain |{check_word}|"
self.verify_speech_recognition_response(audio_response)
def test_audio_ambient_completion(self):
# bird song
......@@ -138,26 +142,39 @@ class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
assert "bird" in audio_response
class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
def test_single_image_chat_completion(self):
class ImageOpenAITestMixin(TestOpenAIMLLMServerBase):
def run_decode_with_image(self, image_id):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
content = []
if image_id == 0:
content.append(
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
}
)
elif image_id == 1:
content.append(
{
"type": "image_url",
"image_url": {"url": IMAGE_SGL_LOGO_URL},
}
)
else:
pass
content.append(
{
"type": "text",
"text": "Describe this image in a sentence.",
}
)
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "text",
"text": "Describe this image in a sentence.",
},
],
},
{"role": "user", "content": content},
],
temperature=0,
**(self.get_vision_request_kwargs()),
......@@ -166,6 +183,17 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
def test_mixed_batch(self):
image_ids = [0, 1, 2] * 4
with ThreadPoolExecutor(4) as executor:
list(executor.map(self.run_decode_with_image, image_ids))
def verify_single_image_response(self, response):
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
# `driver` is for gemma-3-it
assert (
"man" in text or "person" or "driver" in text
......@@ -179,19 +207,44 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
), f"text: {text}, should contain cab, taxi, SUV, vehicle or car"
# MiniCPMO fails to recognize `iron`, but `hanging`
assert (
"iron" in text
or "hang" in text
or "cloth" in text
or "coat" in text
or "holding" in text
or "outfit" in text
), f"text: {text}, should contain iron, hang, cloth, coat or holding or outfit"
"iron" in text or "hang" in text or "cloth" in text or "holding" in text
), f"text: {text}, should contain iron, hang, cloth or holding"
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def test_single_image_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "text",
"text": "Describe this image in a sentence.",
},
],
},
],
temperature=0,
**(self.get_vision_request_kwargs()),
)
print("-" * 30)
print(f"Single image response:\n{response.choices[0].message.content}")
print("-" * 30)
self.verify_single_image_response(response)
def test_multi_turn_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
......@@ -264,8 +317,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
},
{
"type": "text",
"text": "I have two very different images. They are not related at all. "
"Please describe the first image in one sentence, and then describe the second image in another sentence.",
"text": "I have two very different images. Please describe them.",
},
],
},
......@@ -296,64 +348,6 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def _test_mixed_image_audio_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "audio_url",
"audio_url": {"url": AUDIO_TRUMP_SPEECH_URL},
},
{
"type": "text",
"text": "Please describe the image in one sentence, and then write down the audio transcription in English.",
},
],
},
],
temperature=0,
**(self.get_vision_request_kwargs()),
)
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
print("-" * 30)
print(f"Mixed image & audio response:\n{text}")
print("-" * 30)
assert (
"man" in text
or "cab" in text
or "SUV" in text
or "taxi" in text
or "car" in text
), f"text: {text}, should contain man, cab, SUV, taxi or car"
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in text
), f"text: |{text}| should contain |{check_word}|"
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def prepare_video_images_messages(self, video_path):
# the memory consumed by vision attention varies a lot (e.g. blocked qkv vs. full-sequence sdpa),
# and the size of the video embeds depends on the `modality` argument used during preprocessing
......@@ -461,7 +455,7 @@ class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
self.assertGreater(len(video_response), 0)
class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
class VideoOpenAITestMixin(TestOpenAIMLLMServerBase):
def prepare_video_messages(self, video_path):
messages = [
{
......@@ -526,3 +520,45 @@ class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
), f"video_response: {video_response}, should contain 'black' or 'dark'"
self.assertIsNotNone(video_response)
self.assertGreater(len(video_response), 0)
class OmniOpenAITestMixin(
ImageOpenAITestMixin, VideoOpenAITestMixin, AudioOpenAITestMixin
):
def test_mixed_modality_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "audio_url",
"audio_url": {"url": AUDIO_TRUMP_SPEECH_URL},
},
{
"type": "text",
"text": "I have an image and audio, which are not related at all. Please: 1. Describe the image in a sentence, 2. Repeat the exact words from the audio I provided. Be exact",
},
],
},
]
response = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=128,
stream=False,
)
text = response.choices[0].message.content
print("-" * 30)
print(f"Mixed modality response:\n{text}")
print("-" * 30)
self.verify_single_image_response(response=response)
self.verify_speech_recognition_response(text=text)