Unverified Commit 0f7acdd7 authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[Model] Support Qwen3-VL Model Series (#24727)


Signed-off-by: default avatarRoger Wang <hey@rogerw.io>
Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: default avatarHuang Jie <92386084+JJJYmmm@users.noreply.github.com>
Co-authored-by: default avatar松灵 <26085463+wulipc@users.noreply.github.com>
Co-authored-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 5801e497
...@@ -661,6 +661,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen ...@@ -661,6 +661,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-3B`, `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-3B`, `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ | | `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ |
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
......
...@@ -1437,6 +1437,80 @@ def run_qwen2_5_omni(questions: list[str], modality: str): ...@@ -1437,6 +1437,80 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
) )
# Qwen3-VL-Dense
def run_qwen3_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Qwen/Qwen3-VL-4B-Instruct"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
mm_processor_kwargs={
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
"fps": 1,
},
limit_mm_per_prompt={modality: 1},
)
if modality == "image":
placeholder = "<|image_pad|>"
elif modality == "video":
placeholder = "<|video_pad|>"
prompts = [
(
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Qwen3-VL-MOE
def run_qwen3_vl_moe(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Qwen/Qwen3-VL-30B-A3B-Instruct"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
mm_processor_kwargs={
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
"fps": 1,
},
limit_mm_per_prompt={modality: 1},
)
if modality == "image":
placeholder = "<|image_pad|>"
elif modality == "video":
placeholder = "<|video_pad|>"
prompts = [
(
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# R-4B # R-4B
def run_r_vl(questions: list[str], modality: str) -> ModelRequestData: def run_r_vl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image" assert modality == "image"
...@@ -1645,6 +1719,8 @@ model_example_map = { ...@@ -1645,6 +1719,8 @@ model_example_map = {
"qwen2_vl": run_qwen2_vl, "qwen2_vl": run_qwen2_vl,
"qwen2_5_vl": run_qwen2_5_vl, "qwen2_5_vl": run_qwen2_5_vl,
"qwen2_5_omni": run_qwen2_5_omni, "qwen2_5_omni": run_qwen2_5_omni,
"qwen3_vl": run_qwen3_vl,
"qwen3_vl_moe": run_qwen3_vl_moe,
"rvl": run_r_vl, "rvl": run_r_vl,
"skywork_chat": run_skyworkr1v, "skywork_chat": run_skyworkr1v,
"smolvlm": run_smolvlm, "smolvlm": run_smolvlm,
...@@ -1658,6 +1734,8 @@ MODELS_NEED_VIDEO_METADATA = [ ...@@ -1658,6 +1734,8 @@ MODELS_NEED_VIDEO_METADATA = [
"glm4_1v", "glm4_1v",
"glm4_5v", "glm4_5v",
"glm4_5v_fp8", "glm4_5v_fp8",
"qwen3_vl",
"qwen3_vl_moe",
] ]
......
...@@ -31,6 +31,7 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: ...@@ -31,6 +31,7 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
""" """
# Ensure video metadata is included # Ensure video metadata is included
if "video" in mm_data: if "video" in mm_data:
# GLM4.1V doesn't support multiple videos
video = mm_data["video"] video = mm_data["video"]
num_frames = len(video) num_frames = len(video)
mm_data["video"] = (video, { mm_data["video"] = (video, {
...@@ -44,6 +45,34 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: ...@@ -44,6 +45,34 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
return mm_data return mm_data
def qwen3_vl_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
"""
Patch the multimodal data for Qwen3-VL model.
"""
def create_metadata(frames: np.ndarray):
num_frames = len(frames)
return {
"total_num_frames": num_frames,
"fps": 2.0,
"duration": num_frames / 2.0,
"video_backend": "opencv",
"frames_indices": list(range(num_frames)),
"do_sample_frames": True,
}
# Ensure video metadata is included
if "video" in mm_data:
video = mm_data["video"]
if isinstance(video, list):
# multiple videos
mm_data["video"] = [(vid, create_metadata(vid)) for vid in video]
else:
# single video
mm_data["video"] = (video, create_metadata(video))
return mm_data
def _test_processing_correctness( def _test_processing_correctness(
model_id_or_arch: str, model_id_or_arch: str,
hit_rate: float, hit_rate: float,
...@@ -182,8 +211,10 @@ _IGNORE_MM_KEYS = { ...@@ -182,8 +211,10 @@ _IGNORE_MM_KEYS = {
} }
MM_DATA_PATCHES = { MM_DATA_PATCHES = {
# GLM4.1V requires video metadata to be included in the input # GLM4.1V and Qwen3-VL requires video metadata to be included in the input
"glm4v": glm4_1v_patch_mm_data, "glm4v": glm4_1v_patch_mm_data,
"qwen3_vl": qwen3_vl_patch_mm_data,
"qwen3_vl_moe": qwen3_vl_patch_mm_data,
} }
...@@ -326,6 +357,8 @@ def _test_processing_correctness_one( ...@@ -326,6 +357,8 @@ def _test_processing_correctness_one(
"Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct",
"Qwen/Qwen2.5-Omni-3B", "Qwen/Qwen2.5-Omni-3B",
"Qwen/Qwen3-VL-4B-Instruct",
"Qwen/Qwen3-VL-30B-A3B-Instruct",
"YannQi/R-4B", "YannQi/R-4B",
"Skywork/Skywork-R1V-38B", "Skywork/Skywork-R1V-38B",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct", "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
......
...@@ -557,6 +557,12 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -557,6 +557,12 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_model_len=4096), max_model_len=4096),
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"), "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"),
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501
"Qwen3VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-4B-Instruct", # noqa: E501
max_model_len=4096,
min_transformers_version="4.57"), # noqa: E501
"Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501
max_model_len=4096,
min_transformers_version="4.57"),
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B",
trust_remote_code=True), trust_remote_code=True),
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B", "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
......
...@@ -103,6 +103,8 @@ def get_rope( ...@@ -103,6 +103,8 @@ def get_rope(
is_neox_style, is_neox_style,
dtype, dtype,
mrope_section=rope_scaling["mrope_section"], mrope_section=rope_scaling["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved",
False),
) )
else: else:
rotary_emb = RotaryEmbedding( rotary_emb = RotaryEmbedding(
......
...@@ -177,6 +177,18 @@ def triton_mrope( ...@@ -177,6 +177,18 @@ def triton_mrope(
return q, k return q, k
def apply_interleaved_rope(x: torch.Tensor,
mrope_section: list[int]) -> torch.Tensor:
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
interleaved [THTHWHTHW...TT], preserving frequency continuity.
"""
x_t = x[0].clone()
x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
return x_t
class MRotaryEmbedding(RotaryEmbedding): class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections.""" """Rotary Embedding with Multimodal Sections."""
...@@ -189,6 +201,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -189,6 +201,7 @@ class MRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
mrope_section: Optional[list[int]] = None, mrope_section: Optional[list[int]] = None,
mrope_interleaved: Optional[bool] = False,
) -> None: ) -> None:
# In Qwen2.5-VL, the maximum index value is related to the duration of # In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings to 4 times to get # the input video. We enlarge max_position_embeddings to 4 times to get
...@@ -198,6 +211,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -198,6 +211,7 @@ class MRotaryEmbedding(RotaryEmbedding):
base, is_neox_style, dtype) base, is_neox_style, dtype)
self.mrope_section = mrope_section self.mrope_section = mrope_section
self.mrope_interleaved = mrope_interleaved
if self.mrope_section: if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2 assert sum(self.mrope_section) == rotary_dim // 2
...@@ -225,17 +239,20 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -225,17 +239,20 @@ class MRotaryEmbedding(RotaryEmbedding):
cos, sin = cos_sin.chunk(2, dim=-1) cos, sin = cos_sin.chunk(2, dim=-1)
if positions.ndim == 2: if positions.ndim == 2:
assert self.mrope_section assert self.mrope_section
if self.mrope_interleaved:
cos = torch.cat([ cos = apply_interleaved_rope(cos, self.mrope_section)
m[i] sin = apply_interleaved_rope(sin, self.mrope_section)
for i, m in enumerate(cos.split(self.mrope_section, dim=-1)) else:
], cos = torch.cat([
dim=-1) m[i] for i, m in enumerate(
sin = torch.cat([ cos.split(self.mrope_section, dim=-1))
m[i] ],
for i, m in enumerate(sin.split(self.mrope_section, dim=-1)) dim=-1)
], sin = torch.cat([
dim=-1) m[i] for i, m in enumerate(
sin.split(self.mrope_section, dim=-1))
],
dim=-1)
query_shape = query.shape query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size) query = query.view(num_tokens, -1, self.head_size)
...@@ -265,6 +282,10 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -265,6 +282,10 @@ class MRotaryEmbedding(RotaryEmbedding):
assert positions.ndim == 1 or positions.ndim == 2 assert positions.ndim == 1 or positions.ndim == 2
assert key is not None assert key is not None
if self.mrope_interleaved:
# TODO: add triton implementation to support mrope-interleaved
return self.forward_native(positions, query, key)
num_tokens = positions.shape[-1] num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions] cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1) cos, sin = cos_sin.chunk(2, dim=-1)
...@@ -388,6 +409,15 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -388,6 +409,15 @@ class MRotaryEmbedding(RotaryEmbedding):
context_len=context_len, context_len=context_len,
seq_len=seq_len, seq_len=seq_len,
) )
elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]:
return cls._qwen3vl_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]: elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
return cls._ernie_get_input_positions_tensor( return cls._ernie_get_input_positions_tensor(
input_tokens=input_tokens, input_tokens=input_tokens,
...@@ -526,6 +556,98 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -526,6 +556,98 @@ class MRotaryEmbedding(RotaryEmbedding):
len(input_tokens)).item() len(input_tokens)).item()
return llm_positions, mrope_position_delta return llm_positions, mrope_position_delta
@classmethod
def _qwen3vl_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""
video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw
for _ in range(t)]
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
@classmethod @classmethod
def _ernie_get_input_positions_tensor( def _ernie_get_input_positions_tensor(
cls, cls,
......
...@@ -285,7 +285,7 @@ class Qwen2Model(nn.Module): ...@@ -285,7 +285,7 @@ class Qwen2Model(nn.Module):
decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer): decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config.get_text_config()
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
......
...@@ -83,7 +83,7 @@ from .vision import get_vit_attn_backend ...@@ -83,7 +83,7 @@ from .vision import get_vit_attn_backend
logger = init_logger(__name__) logger = init_logger(__name__)
# For profile run # For profile run
_MAX_FRAMES_PER_VIDEO = 16 _MAX_FRAMES_PER_VIDEO = 600
# === Vision Inputs === # # === Vision Inputs === #
......
...@@ -378,7 +378,7 @@ class Qwen3MoeModel(nn.Module): ...@@ -378,7 +378,7 @@ class Qwen3MoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config.get_text_config()
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Any, Callable, Optional, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen3_vl import (Qwen3VLProcessor,
Qwen3VLVideoProcessor)
from transformers.models.qwen3_vl.configuration_qwen3_vl import (
Qwen3VLConfig, Qwen3VLVisionConfig)
from transformers.video_utils import VideoMetadata
from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItem,
MultiModalKwargsItems, VideoItem)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import is_list_of
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .qwen2_5_vl import (Qwen2_5_VisionAttention,
Qwen2_5_VisionRotaryEmbedding,
Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLImageInputs,
Qwen2_5_VLImagePixelInputs,
Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLVideoInputs,
Qwen2_5_VLVideoPixelInputs)
from .qwen2_vl import Qwen2VLProcessingInfo
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vit_attn_backend
logger = init_logger(__name__)
class Qwen3_VisionPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 14,
temporal_patch_size: int = 2,
in_channels: int = 3,
hidden_size: int = 1152,
) -> None:
super().__init__()
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = nn.Conv3d(in_channels,
hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size,
self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x
class Qwen3_VisionMLP(nn.Module):
def __init__(self,
in_features: int,
hidden_features: int,
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.linear_fc1 = ColumnParallelLinear(in_features,
hidden_features,
bias=bias,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.linear_fc1")
self.linear_fc2 = RowParallelLinear(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.linear_fc2")
self.act_fn = act_fn
def forward(self, x: torch.Tensor):
mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
return mlp_output
class Qwen3_VisionBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
self.attn = Qwen2_5_VisionAttention(embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.mlp = Qwen3_VisionMLP(dim,
mlp_hidden_dim,
act_fn=act_fn,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: Optional[int] = None, # Only used for Flash Attention
seqlens: Optional[list[int]] = None, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens)
x = x + self.mlp(self.norm2(x))
return x
class Qwen3_VisionPatchMerger(nn.Module):
def __init__(
self,
d_model: int,
context_dim: int,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
spatial_merge_size: int = 2,
use_postshuffle_norm: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
self.use_postshuffle_norm = use_postshuffle_norm
if self.use_postshuffle_norm:
context_dim = self.hidden_size
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.use_postshuffle_norm = use_postshuffle_norm
self.norm = norm_layer(
self.hidden_size if use_postshuffle_norm else context_dim)
self.linear_fc1 = ColumnParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc1")
self.act_fn = nn.GELU()
self.linear_fc2 = RowParallelLinear(self.hidden_size,
d_model,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc2")
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_postshuffle_norm:
x = self.norm(x.view(-1, self.hidden_size))
else:
x = self.norm(x).view(-1, self.hidden_size)
x_parallel, _ = self.linear_fc1(x)
x_parallel = self.act_fn(x_parallel)
out, _ = self.linear_fc2(x_parallel)
return out
class Qwen3_VisionTransformer(nn.Module):
def __init__(
self,
vision_config: Qwen3VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.num_position_embeddings = vision_config.num_position_embeddings
self.patch_size = vision_config.patch_size
self.spatial_merge_size = vision_config.spatial_merge_size
self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.patch_embed = Qwen3_VisionPatchEmbed(
patch_size=self.patch_size,
temporal_patch_size=self.temporal_patch_size,
in_channels=vision_config.in_channels,
hidden_size=self.hidden_size,
)
self.pos_embed = nn.Embedding(self.num_position_embeddings,
self.hidden_size)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList([
Qwen3_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}")
for layer_idx in range(vision_config.depth)
])
self.merger = Qwen3_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=self.hidden_size,
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
prefix=f"{prefix}.merger",
)
self.deepstack_merger_list = nn.ModuleList([
Qwen3_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=self.hidden_size,
spatial_merge_size=self.spatial_merge_size,
use_postshuffle_norm=True,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}")
for layer_idx in range(len(self.deepstack_visual_indexes))
])
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
@property
def dtype(self) -> torch.dtype:
return self.patch_embed.proj.weight.dtype
@property
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device
def rot_pos_emb(self, grid_thw):
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def fast_pos_embed_interpolate(self, grid_thw):
num_grid_per_side = int(self.num_position_embeddings**0.5)
idx_list = [[] for _ in range(4)]
weight_list = [[] for _ in range(4)]
for t, h, w in grid_thw:
h_idxs = torch.linspace(0,
num_grid_per_side - 1,
h,
dtype=torch.float32)
w_idxs = torch.linspace(0,
num_grid_per_side - 1,
w,
dtype=torch.float32)
h_idxs_floor = h_idxs.to(torch.long)
w_idxs_floor = w_idxs.to(torch.long)
h_idxs_ceil = torch.clamp(h_idxs.to(torch.long) + 1,
max=num_grid_per_side - 1)
w_idxs_ceil = torch.clamp(w_idxs.to(torch.long) + 1,
max=num_grid_per_side - 1)
dh = h_idxs - h_idxs_floor
dw = w_idxs - w_idxs_floor
idx_list[0].extend(((h_idxs_floor * num_grid_per_side)[None].T +
w_idxs_floor[None]).flatten().tolist() * t)
idx_list[1].extend(((h_idxs_floor * num_grid_per_side)[None].T +
w_idxs_ceil[None]).flatten().tolist() * t)
idx_list[2].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
w_idxs_floor[None]).flatten().tolist() * t)
idx_list[3].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
w_idxs_ceil[None]).flatten().tolist() * t)
weight_list[0].extend(
((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t)
weight_list[1].extend(
((1 - dh)[None].T * dw[None]).flatten().tolist() * t)
weight_list[2].extend(
(dh[None].T * (1 - dw)[None]).flatten().tolist() * t)
weight_list[3].extend(
(dh[None].T * dw[None]).flatten().tolist() * t)
device = self.pos_embed.weight.device
dtype = self.pos_embed.weight.dtype
p0 = self.pos_embed(
torch.tensor(
idx_list[0], dtype=torch.long, device=device)) * torch.tensor(
weight_list[0], dtype=dtype, device=device)[:, None]
p1 = self.pos_embed(
torch.tensor(
idx_list[1], dtype=torch.long, device=device)) * torch.tensor(
weight_list[1], dtype=dtype, device=device)[:, None]
p2 = self.pos_embed(
torch.tensor(
idx_list[2], dtype=torch.long, device=device)) * torch.tensor(
weight_list[2], dtype=dtype, device=device)[:, None]
p3 = self.pos_embed(
torch.tensor(
idx_list[3], dtype=torch.long, device=device)) * torch.tensor(
weight_list[3], dtype=dtype, device=device)[:, None]
patch_pos_embeds = p0 + p1 + p2 + p3
patch_pos_embeds = patch_pos_embeds.split(
[t * h * w for t, h, w in grid_thw])
patch_pos_embeds_permute = []
m_size = self.spatial_merge_size
for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw):
pos_embed = pos_embed.view(t, h // m_size, m_size, w // m_size,
m_size, -1).permute(0, 1, 3, 2, 4,
5).flatten(0, 4)
patch_pos_embeds_permute.append(pos_embed)
patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
return patch_pos_embeds
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
if self.attn_backend == _Backend.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens
def forward(
self,
x: torch.Tensor,
grid_thw: list[list[int]],
) -> torch.Tensor:
hidden_states = x.to(device=self.device, dtype=self.dtype)
hidden_states = self.patch_embed(hidden_states)
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
dim=0,
dtype=grid_thw.dtype
if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
hidden_states = hidden_states.unsqueeze(1)
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
deepstack_feature_lists = []
for layer_num, blk in enumerate(self.blocks):
hidden_states = blk(hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(
layer_num)
deepstack_feature = self.deepstack_merger_list[
deepstack_merger_idx](hidden_states)
deepstack_feature_lists.append(deepstack_feature)
hidden_states = self.merger(hidden_states)
hidden_states = torch.cat(
[hidden_states] + deepstack_feature_lists,
dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("attn.qkv.", "attn.q.", "q"),
("attn.qkv.", "attn.k.", "k"),
("attn.qkv.", "attn.v.", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen3VLConfig)
def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor:
return self.ctx.get_hf_processor(
Qwen3VLProcessor,
use_fast=kwargs.pop("use_fast", True),
**kwargs,
)
def get_tokenizer(self):
return self.ctx.tokenizer
def get_image_processor(self,
**kwargs: object) -> Qwen2VLImageProcessorFast:
return self.get_hf_processor(**kwargs).image_processor
def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor:
return self.get_hf_processor(**kwargs).video_processor
def _get_vision_info(
self,
*,
image_width: int,
image_height: int,
num_frames: int = 2,
do_resize: bool = True,
image_processor: Optional[Qwen2VLImageProcessorFast],
) -> tuple[ImageSize, int]:
if image_processor is None:
image_processor = self.get_image_processor()
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
if do_resize:
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * merge_size,
min_pixels=image_processor.size["shortest_edge"],
max_pixels=image_processor.size["longest_edge"],
)
preprocessed_size = ImageSize(width=resized_width,
height=resized_height)
else:
preprocessed_size = ImageSize(width=image_width,
height=image_height)
padded_num_frames = num_frames + num_frames % temporal_patch_size
grid_t = max(padded_num_frames // temporal_patch_size, 1)
grid_h = preprocessed_size.height // patch_size
grid_w = preprocessed_size.width // patch_size
num_patches = grid_t * grid_h * grid_w
num_vision_tokens = num_patches // (merge_size**2)
return preprocessed_size, num_vision_tokens
def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
video_fps: float, merge_size: int):
if not isinstance(indices, list):
indices = indices.tolist()
if len(indices) % merge_size != 0:
# don't update metadata's frames_indices directly
indices = indices + [indices[-1]
] * (merge_size - len(indices) % merge_size)
timestamps = [idx / video_fps for idx in indices]
timestamps = [(timestamps[i] + timestamps[i + merge_size - 1]) / 2
for i in range(0, len(timestamps), merge_size)]
return timestamps
def _get_video_second_idx(
self,
metadata: dict[str, Any],
out_item: MultiModalKwargsItem,
do_sample_frames: Optional[bool] = None,
sampled_fps: Optional[float] = None) -> list[int]:
video_processor = self.get_video_processor()
merge_size = video_processor.merge_size
indices = metadata["frames_indices"]
# metadata["fps"] refers to the true fps of the input video.
video_fps = metadata["fps"]
if do_sample_frames is None:
do_sample_frames = metadata.get("do_sample_frames", False)
# If video frames are sampled in HF processor (instead of vLLM
# video loader), we need to re-calculate the indices from original
# metadata.
if do_sample_frames:
# here video_fps is the fps of the sampled video, and
# metadata["fps"] refers to the fps of the original video.
video_fps = sampled_fps if sampled_fps else video_processor.fps
total_num_frames = metadata["total_num_frames"]
num_frames = int(total_num_frames / metadata["fps"] * video_fps)
num_frames = min(
min(max(num_frames, video_processor.min_frames),
video_processor.max_frames), total_num_frames)
indices = np.linspace(0, total_num_frames - 1,
num_frames).round().astype(int).tolist()
timestamps = self._calculate_timestamps(indices, video_fps, merge_size)
return timestamps
class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
video_token = "<|vision_start|><|video_pad|><|vision_end|>"
return image_token * num_images + video_token * num_videos
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
target_width, target_height = (
self.info.get_image_size_with_most_features())
target_num_frames = self.info.get_num_frames_with_most_features(
seq_len, mm_counts)
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=target_num_frames,
num_videos=num_videos,
),
}
def _get_dummy_videos(
self,
*,
width: int,
height: int,
num_frames: int,
num_videos: int,
) -> list[VideoItem]:
num_frames = max(num_frames, 2)
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
video_items = []
for i in range(num_videos):
video_metadata = {
"fps": 2.0,
"duration": num_frames / 2.0,
"total_num_frames": num_frames,
"frames_indices": [i for i in range(num_frames)],
"video_backend": "opencv",
"do_sample_frames": False,
}
video_item = (video.copy(), video_metadata)
video_items.append(video_item)
return video_items
class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]
):
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(video_needs_metadata=True)
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_data = dict(mm_data)
processor = self.info.get_hf_processor(**mm_kwargs)
# Separate video processing from image processing. Because the videos
# are processed into serval image patches
if ("videos" in mm_data and isinstance(mm_data["videos"], list)
and len(mm_data["videos"]) > 0):
video_grid_thw_lst = []
pixel_values_videos_lst = []
for item_idx, item in enumerate(mm_data.pop("videos", [])):
video_array, metadata = item
# NOTE: @JJJYmmm new attr metadata.frames_indices indicates
# the sampled frames indices of pre-sampled videos, which is
# used to calculate the timestamps. Make sure that
# do_sample_frames in mm_kwargs is false for presampled videos.
# NOTE: a copy of is created to update do_sample_frames,
# otherwise mm_hash for the object will be incorrect.
video_mm_kwargs = dict(**mm_kwargs)
if "do_sample_frames" not in video_mm_kwargs:
# qwen_vl_utils already has "do_sample_frames" in
# mm_kwargs, don't overwrite it.
video_mm_kwargs["do_sample_frames"] = metadata.get(
"do_sample_frames", False)
metadata = VideoMetadata(**{
k: metadata[k]
for k in metadata if k != "do_sample_frames"
})
video_mm_data = dict()
video_mm_data["videos"] = [[video_array]]
video_mm_data["video_metadata"] = [[metadata]]
video_outputs = super()._call_hf_processor(
prompt="<|vision_start|><|video_pad|><|vision_end|>",
mm_data=video_mm_data,
mm_kwargs=video_mm_kwargs,
tok_kwargs=tok_kwargs,
)
input_ids = video_outputs.pop("input_ids")
video_placeholder = processor.tokenizer.batch_decode(
input_ids)[0]
prompt = prompt.replace(
"<|vision_start|><|video_pad|><|vision_end|>",
video_placeholder,
1,
)
video_grid_thw_lst.append(video_outputs["video_grid_thw"])
pixel_values_videos_lst.append(
video_outputs["pixel_values_videos"])
video_outputs = dict(
pixel_values_videos=torch.cat(pixel_values_videos_lst),
video_grid_thw=torch.cat(video_grid_thw_lst),
)
else:
video_outputs = dict()
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
combined_outputs = dict(
processed_outputs,
**video_outputs,
)
return BatchFeature(combined_outputs)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_grid_sizes = image_grid_thw.prod(-1)
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
hf_config = self.info.get_hf_config()
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
vision_end_token_id = hf_config.vision_end_token_id
merge_length = image_processor.merge_size**2
def get_image_replacement_qwen3vl(item_idx: int):
out_item = out_mm_kwargs["image"][item_idx]
grid_thw = out_item["image_grid_thw"].data
assert isinstance(grid_thw, torch.Tensor)
num_tokens = int(grid_thw.prod()) // merge_length
return [hf_processor.image_token_id] * num_tokens
def get_video_replacement_qwen3vl(item_idx: int):
out_item = out_mm_kwargs["video"][item_idx]
grid_thw = out_item["video_grid_thw"].data
assert isinstance(grid_thw, torch.Tensor)
video, metadata = mm_items["video"][item_idx]
do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames")
sampled_fps = hf_processor_mm_kwargs.get("fps")
if is_list_of(sampled_fps, float):
sampled_fps = sampled_fps[item_idx]
timestamps = self.info._get_video_second_idx(
metadata, out_item, do_sample_frames, sampled_fps)
assert len(timestamps) == grid_thw[0], (
f"The timestamps length({len(timestamps)}) should be equal "
f"video length ({grid_thw[0]}).")
frames_idx_token = [
tokenizer.encode(f"<{curr_time:.1f} seconds>",
add_special_tokens=False)
for curr_time in timestamps
]
num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
placeholder = []
for frame_idx in frames_idx_token:
placeholder.extend(frame_idx)
placeholder.extend([vision_start_token_id] +
[video_token_id] * num_tokens_per_frame +
[vision_end_token_id])
return PromptUpdateDetails.select_token_id(placeholder,
video_token_id)
return [
PromptReplacement(
modality="image",
target=hf_processor.image_token,
replacement=get_image_replacement_qwen3vl,
),
# NOTE: We match string on purpose since searching sequence of
# token ids takes more time.
PromptReplacement(
modality="video",
target="<|vision_start|><|video_pad|><|vision_end|>",
replacement=get_video_replacement_qwen3vl,
),
]
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
# the same shape as input_embeds
"deepstack_input_embeds": 0
})
class Qwen3LLMModel(Qwen3Model):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
if not get_pp_group().is_first_rank:
assert self.start_layer >= len(
vllm_config.model_config.hf_config.vision_config.
deepstack_visual_indexes), (
"start_layer should be greater than or equal to "
"len(deepstack_visual_indexes)")
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
# args for deepstack
deepstack_input_embeds: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer_idx, layer in enumerate(
self.layers[self.start_layer:self.end_layer]):
layer_idx = layer_idx + self.start_layer
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
if deepstack_input_embeds is not None and \
layer_idx in range(0, len(deepstack_input_embeds)):
hidden_states = hidden_states + deepstack_input_embeds[
f"deepstack_input_embeds_{layer_idx}"]
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super(Qwen3ForCausalLM, self).__init__()
config = vllm_config.model_config.hf_config.text_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Qwen3LLMModel(vllm_config=vllm_config, prefix=prefix)
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix="lm_head")
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
info=Qwen3VLProcessingInfo,
dummy_inputs=Qwen3VLDummyInputsBuilder)
class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.visual.": "visual.",
"lm_head.": "language_model.lm_head.",
"model.language_model.": "language_model.model.",
})
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
return "<|vision_start|><|image_pad|><|vision_end|>"
if modality.startswith("video"):
return "<|vision_start|><|video_pad|><|vision_end|>"
raise ValueError("Only image or video modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
super().__init__()
config: Qwen3VLConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix,
"language_model"))
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
self.use_deepstack = hasattr(config.vision_config,
'deepstack_visual_indexes')
self.deepstack_num_level = len(
config.vision_config.deepstack_visual_indexes
) if self.use_deepstack else 0
# register buffer for deepstack
self.deepstack_input_embeds = [
torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size)
for _ in range(self.deepstack_num_level)
] if self.use_deepstack else None
def _get_deepstack_input_embeds(self,
num_tokens: int) -> IntermediateTensors:
# get deepstack_input_embeds from buffer, and clear the buffer
return IntermediateTensors({
f"deepstack_input_embeds_{idx}":
self.deepstack_input_embeds[idx][:num_tokens]
for idx in range(self.deepstack_num_level)
})
def _set_deepstack_input_embeds(
self, deepstack_input_embeds: torch.Tensor) -> None:
# set deepstack_input_embeds to buffer
num_tokens = deepstack_input_embeds.size(1)
if num_tokens > self.deepstack_input_embeds[0].size(0):
self.deepstack_input_embeds = [
torch.zeros(num_tokens,
self.config.text_config.hidden_size,
device=self.deepstack_input_embeds[0].device,
dtype=self.deepstack_input_embeds[0].dtype)
for _ in range(self.deepstack_num_level)
]
for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].copy_(
deepstack_input_embeds[idx])
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
# clear deepstack_input_embeds in buffer
if num_tokens > 0:
for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].zero_()
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models.
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == 2:
return mm_input
if mm_input.ndim != 3:
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim} "
f"(shape={mm_input.shape})")
return torch.concat(list(mm_input))
else:
return torch.concat(mm_input)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")
return Qwen2_5_VLImagePixelInputs(type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw)
if image_embeds is not None:
image_embeds = self._validate_and_reshape_mm_tensor(
image_embeds, "image embeds")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Qwen2_5_VLImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds,
image_grid_thw=image_grid_thw)
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_embeds = kwargs.pop("video_embeds", None)
video_grid_thw = kwargs.pop("video_grid_thw", None)
second_per_grid_ts = kwargs.pop("second_per_grid_ts", None)
if pixel_values_videos is None and video_embeds is None:
return None
if pixel_values_videos is not None:
pixel_values_videos = self._validate_and_reshape_mm_tensor(
pixel_values_videos, "video pixel values")
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
return Qwen2_5_VLVideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
)
if video_embeds is not None:
video_embeds = self._validate_and_reshape_mm_tensor(
video_embeds, "video embeds")
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
if not isinstance(video_embeds, torch.Tensor):
raise ValueError("Incorrect type of video embeddings. "
f"Got type: {type(video_embeds)}")
return Qwen2_5_VLVideoEmbeddingInputs(
type="video_embeds",
video_embeds=video_embeds,
video_grid_thw=video_grid_thw)
def _process_image_input(
self,
image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
grid_thw = image_input["image_grid_thw"]
assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist()
if image_input["type"] == "image_embeds":
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
# Split concatenated embeddings for each image item.
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
merge_size = self.visual.spatial_merge_size
sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
(merge_size * merge_size)).tolist()
return image_embeds.split(sizes)
def _process_video_input(
self,
video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]:
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist()
if video_input["type"] == "video_embeds":
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
else:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
# Split concatenated embeddings for each video item.
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
merge_size = self.visual.spatial_merge_size
sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
(merge_size * merge_size)).tolist()
return video_embeds.split(sizes)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
for input_key in kwargs:
if input_key in ("pixel_values", "image_embeds"
) and "image" not in mm_input_by_modality:
mm_input_by_modality[
"image"] = self._parse_and_validate_image_input(**kwargs)
if input_key in ("pixel_values_videos", "video_embeds"
) and "video" not in mm_input_by_modality:
mm_input_by_modality[
"video"] = self._parse_and_validate_video_input(**kwargs)
return mm_input_by_modality
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
**kwargs)
if not mm_input_by_modality:
return None
# The result multimodal_embeddings is tuple of tensors, with each
# tensor correspoending to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in mm_input_by_modality:
multimodal_input = mm_input_by_modality[modality]
if modality == "image":
vision_embeddings = self._process_image_input(multimodal_input)
multimodal_embeddings += vision_embeddings
if modality == "video":
video_embeddings = self._process_video_input(multimodal_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings
def _compute_deepstack_embeds(
self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings) -> torch.Tensor:
visual_lens = [
x.shape[0] if isinstance(x, torch.Tensor) else len(x)
for x in multimodal_embeddings
]
multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)
visual_dim = multimodal_embeddings_cat.shape[-1] // (
self.deepstack_num_level + 1)
main_dim, multi_dim = visual_dim, visual_dim * self.deepstack_num_level
multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split( # noqa:E501
multimodal_embeddings_cat, [main_dim, multi_dim],
dim=-1)
multimodal_embeddings = torch.split(multimodal_embeddings_main,
visual_lens,
dim=0)
multimodal_embeddings_multiscale = torch.split(
multimodal_embeddings_multiscale, visual_lens, dim=0)
deepstack_input_embeds = inputs_embeds.new_zeros(
inputs_embeds.size(0),
self.deepstack_num_level * inputs_embeds.size(1))
deepstack_input_embeds = merge_multimodal_embeddings(
input_ids,
deepstack_input_embeds,
multimodal_embeddings_multiscale,
placeholder_token_id=[
self.config.image_token_id, self.config.video_token_id
],
)
deepstack_input_embeds = deepstack_input_embeds.view(
inputs_embeds.shape[0], self.deepstack_num_level,
visual_dim).contiguous()
deepstack_input_embeds = deepstack_input_embeds.permute(
1, 0, 2).contiguous()
return deepstack_input_embeds, multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
deepstack_input_embeds = None
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None and self.use_deepstack:
deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds( # noqa:E501
input_ids, inputs_embeds, multimodal_embeddings)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
if self.use_deepstack:
if deepstack_input_embeds is None:
deepstack_input_embeds = torch.zeros_like(
inputs_embeds).unsqueeze(0).repeat(
self.deepstack_num_level, 1, 1).contiguous()
self._set_deepstack_input_embeds(deepstack_input_embeds)
return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,
image_input: Optional[Qwen2_5_VLImageInputs] = None,
video_input: Optional[Qwen2_5_VLVideoInputs] = None,
) -> torch.Tensor:
inputs_embeds = self.get_input_embeddings(input_ids)
if self.use_deepstack:
visual_dim = inputs_embeds.shape[-1]
deepstack_input_embeds = None
if image_input is not None or video_input is not None:
deepstack_input_embeds = torch.zeros_like(
inputs_embeds).unsqueeze(1).repeat(
1, self.deepstack_num_level, 1).flatten(1)
if image_input is not None:
image_embeds = self._process_image_input(image_input)
if self.use_deepstack:
image_embeds = torch.cat(image_embeds)
image_embeds, image_embeds_multiscale = image_embeds.split(
[visual_dim, visual_dim * self.deepstack_num_level],
dim=-1)
deepstack_input_embeds = merge_multimodal_embeddings(
input_ids,
deepstack_input_embeds,
image_embeds_multiscale,
placeholder_token_id=self.config.image_token_id,
)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=self.config.image_token_id,
)
if video_input is not None:
video_embeds = self._process_video_input(video_input)
if self.use_deepstack:
video_embeds = torch.cat(video_embeds)
video_embeds, video_embeds_multiscale = video_embeds.split(
[visual_dim, visual_dim * self.deepstack_num_level],
dim=-1)
deepstack_input_embeds = merge_multimodal_embeddings(
input_ids,
deepstack_input_embeds,
video_embeds_multiscale,
placeholder_token_id=self.config.video_token_id,
)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
video_embeds,
placeholder_token_id=self.config.video_token_id,
)
if self.use_deepstack and deepstack_input_embeds is not None:
deepstack_input_embeds = deepstack_input_embeds.view(
inputs_embeds.shape[0], self.deepstack_num_level,
visual_dim).permute(1, 0, 2).contiguous()
self._set_deepstack_input_embeds(deepstack_input_embeds)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for Qwen3VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (default setting for Qwen3VL
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,).
pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
`None` if no images are passed.
pixel_values_videos: Pixel values of videos to be fed to a model.
`None` if no videos are passed.
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed.
"""
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner from
# `get_multimodal_embeddings` and `get_input_embeddings`, this
# condition is only for v0 compatibility.
elif inputs_embeds is None:
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
if image_input is None and video_input is None:
inputs_embeds = None
else:
if uses_mrope(self.config):
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}")
inputs_embeds = self.get_input_embeddings_v0(
input_ids,
image_input=image_input,
video_input=video_input)
input_ids = None
if self.use_deepstack and inputs_embeds is not None and get_pp_group(
).is_first_rank:
deepstack_input_embeds = self._get_deepstack_input_embeds(
inputs_embeds.size(0))
else:
deepstack_input_embeds = None
hidden_states = self.language_model.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
# args for deepstack
deepstack_input_embeds=deepstack_input_embeds,
)
if inputs_embeds is not None and get_pp_group().is_first_rank:
self._clear_deepstack_input_embeds(inputs_embeds.size(0))
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="model.visual.merger",
tower_model="model.visual.",
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3-VL-MoE model compatible with HuggingFace weights."""
import typing
from collections.abc import Iterable
from typing import Callable, Optional, Union
import torch
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import (
Qwen3VLMoeConfig)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
from .qwen3_vl import (Qwen3_VisionTransformer, Qwen3VLDummyInputsBuilder,
Qwen3VLForConditionalGeneration,
Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
from .utils import is_pp_missing_parameter, maybe_prefix
logger = init_logger(__name__)
class Qwen3VLMoeProcessingInfo(Qwen3VLProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen3VLMoeConfig)
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
# the same shape as input_embeds
"deepstack_input_embeds": 0
})
class Qwen3MoeLLMModel(Qwen3MoeModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
if not get_pp_group().is_first_rank:
assert self.start_layer >= len(
vllm_config.model_config.hf_config.vision_config.
deepstack_visual_indexes), (
"start_layer should be greater than or equal to "
"len(deepstack_visual_indexes)")
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
deepstack_input_embeds: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer_idx, layer in enumerate(
self.layers[self.start_layer:self.end_layer]):
layer_idx = layer_idx + self.start_layer
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
if deepstack_input_embeds is not None and \
layer_idx in range(0, len(deepstack_input_embeds)):
hidden_states = hidden_states + deepstack_input_embeds[
f"deepstack_input_embeds_{layer_idx}"]
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_fused_expert_weights(self, name: str, params_dict: dict,
loaded_weight: torch.Tensor, shard_id: str,
num_experts: int):
param = params_dict[name]
weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
for expert_id in range(num_experts):
curr_expert_weight = loaded_weight[expert_id]
success = weight_loader(param,
curr_expert_weight,
name,
shard_id,
expert_id,
return_success=True)
if not success:
return False
return True
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Skip loading extra parameters for GPTQ/modelopt models.
ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale",
".v_scale", "_v_scale", ".weight_scale",
"_weight_scale", ".input_scale", "_input_scale")
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
is_fused_expert = False
fused_expert_params_mapping = [
("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
("experts.w2_weight", "experts.down_proj", 0, "w2"),
]
num_experts = self.config.num_experts
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if ("experts.gate_up_proj" in name
or "experts.down_proj" in name):
is_fused_expert = True
expert_params_mapping = fused_expert_params_mapping
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(ignore_suffixes) and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name.endswith("scale"):
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, loaded_weight)
else:
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
name_mapped = name.replace(weight_name, param_name)
if is_fused_expert:
loaded_weight = loaded_weight.transpose(-1,
-2) # no bias
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-2)
success_w1 = self.load_fused_expert_weights(
name_mapped, params_dict, loaded_weight[0],
"w1", num_experts)
success_w3 = self.load_fused_expert_weights(
name_mapped, params_dict, loaded_weight[1],
"w3", num_experts)
success = success_w1 and success_w3
else:
# down_proj
success = self.load_fused_expert_weights(
name_mapped, params_dict, loaded_weight,
shard_id, num_experts)
else:
if is_pp_missing_parameter(name_mapped, self):
continue
# Skip loading extra parameters for GPTQ/modelopt models
if name_mapped.endswith(
ignore_suffixes
) and name_mapped not in params_dict:
continue
param = params_dict[name_mapped]
# We should ask the weight loader to return success or
# not here since otherwise we may skip experts with
# other available replicas.
weight_loader = typing.cast(Callable[..., bool],
param.weight_loader)
success = weight_loader(param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True)
if success:
name = name_mapped
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(
ignore_suffixes) and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
logger.warning_once(
"Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501
name,
remapped_kv_scale_name,
)
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super(Qwen3MoeForCausalLM, self).__init__()
self.config = vllm_config.model_config.hf_config.text_config
self.quant_config = vllm_config.quant_config
self.model = Qwen3MoeLLMModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size,
quant_config=self.quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
info=Qwen3VLMoeProcessingInfo,
dummy_inputs=Qwen3VLDummyInputsBuilder)
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super(Qwen3VLForConditionalGeneration, self).__init__()
config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix,
"language_model"))
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
self.use_deepstack = hasattr(config.vision_config,
'deepstack_visual_indexes')
self.deepstack_num_level = len(
config.vision_config.deepstack_visual_indexes
) if self.use_deepstack else 0
# register buffer for deepstack
self.deepstack_input_embeds = [
torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size)
for _ in range(self.deepstack_num_level)
] if self.use_deepstack else None
...@@ -259,11 +259,13 @@ _MULTIMODAL_MODELS = { ...@@ -259,11 +259,13 @@ _MULTIMODAL_MODELS = {
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"), "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"), # noqa: E501
"Qwen3VLMoeForConditionalGeneration": ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"), # noqa: E501
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"), "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
"Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501 "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501 "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
# [Encoder-decoder] # [Encoder-decoder]
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501 "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
......
...@@ -156,7 +156,7 @@ class OpenCVVideoBackend(VideoLoader): ...@@ -156,7 +156,7 @@ class OpenCVVideoBackend(VideoLoader):
# can cause incorrect timestamp calculation without num_frames=-1. # can cause incorrect timestamp calculation without num_frames=-1.
metadata = { metadata = {
"total_num_frames": num_frames, "total_num_frames": num_frames,
"fps": original_fps, "fps": num_frames / duration,
"duration": duration, "duration": duration,
"video_backend": "opencv", "video_backend": "opencv",
"frames_indices": list(range(num_frames)), "frames_indices": list(range(num_frames)),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment