Commit b12c902b authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'cx/v0.11.0-dev' into v0.11.0-dev-omni

parents c16e075a f39afa4a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Optional, Union
import numpy as np
import torch
from transformers import PretrainedConfig
from vllm.triton_utils import tl, triton
......@@ -62,10 +55,8 @@ def _triton_mrope_forward(
# Updated offsets for half head_dim
cos_offsets = tl.arange(0, pad_hd // 2)
if is_interleaved:
h_mask = (((cos_offsets % 3) == 1) &
(cos_offsets <= 3 * mrope_section_h))
w_mask = (((cos_offsets % 3) == 2) &
(cos_offsets <= 3 * mrope_section_w))
h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h)
w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w)
t_mask = ~(h_mask | w_mask)
else:
t_end = mrope_section_t
......@@ -89,21 +80,25 @@ def _triton_mrope_forward(
# program instance (i.e. for the current token) separately
# ####################################################################
# left half of the head
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(
0, pad_hd // 2)[None, :]
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(
0, pad_hd // 2)[None, :]
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
0, pad_hd // 2)[None, :] < rd // 2)
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
0, pad_hd // 2)[None, :] < rd // 2)
first_half_q_offsets = (
tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
)
first_half_k_offsets = (
tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
)
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
tl.arange(0, pad_hd // 2)[None, :] < rd // 2
)
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
tl.arange(0, pad_hd // 2)[None, :] < rd // 2
)
q_tile_1 = tl.load(q_ptr + first_half_q_offsets,
mask=first_q_mask,
other=0).to(sin_row.dtype)
k_tile_1 = tl.load(k_ptr + first_half_k_offsets,
mask=first_k_mask,
other=0).to(sin_row.dtype)
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
sin_row.dtype
)
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
sin_row.dtype
)
# right half of the head
second_half_q_offsets = first_half_q_offsets + (rd // 2)
......@@ -111,12 +106,12 @@ def _triton_mrope_forward(
second_q_mask = first_q_mask
second_k_mask = first_k_mask
q_tile_2 = tl.load(q_ptr + second_half_q_offsets,
mask=second_q_mask,
other=0).to(sin_row.dtype)
k_tile_2 = tl.load(k_ptr + second_half_k_offsets,
mask=second_k_mask,
other=0).to(sin_row.dtype)
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
sin_row.dtype
)
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
sin_row.dtype
)
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
# Since cos and sin are now half-size,
......@@ -168,7 +163,7 @@ def triton_mrope(
cos = cos.contiguous()
sin = sin.contiguous()
_triton_mrope_forward[(n_row, )](
_triton_mrope_forward[(n_row,)](
q,
k,
cos,
......@@ -189,15 +184,14 @@ def triton_mrope(
return q, k
def apply_interleaved_rope(x: torch.Tensor,
mrope_section: list[int]) -> torch.Tensor:
def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
interleaved [THTHWHTHW...TT], preserving frequency continuity.
"""
x_t = x[0].clone()
x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
return x_t
......@@ -212,17 +206,16 @@ class MRotaryEmbedding(RotaryEmbedding):
base: float,
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[list[int]] = None,
mrope_section: list[int] | None = None,
mrope_interleaved: bool = False,
# YaRN parameters.
*,
scaling_factor: Optional[float] = None,
scaling_factor: float | None = None,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
......@@ -230,8 +223,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self.beta_slow = beta_slow
if self.scaling_factor is not None:
# Get n-d magnitude scaling corrected for interpolation
self.mscale = float(
yarn_get_mscale(self.scaling_factor) * attn_factor)
self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
else:
self.mscale = 1.0
......@@ -239,8 +231,14 @@ class MRotaryEmbedding(RotaryEmbedding):
# the input video. We enlarge max_position_embeddings to 4 times to get
# a larger the cos and sin cache.
self.cache_max_position_num = max_position_embeddings * 4
super().__init__(head_size, rotary_dim, self.cache_max_position_num,
base, is_neox_style, dtype)
super().__init__(
head_size,
rotary_dim,
self.cache_max_position_num,
base,
is_neox_style,
dtype,
)
self.mrope_section = mrope_section
self.mrope_interleaved = mrope_interleaved
......@@ -261,9 +259,9 @@ class MRotaryEmbedding(RotaryEmbedding):
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""PyTorch-native implementation equivalent to forward().
Args:
......@@ -286,31 +284,27 @@ class MRotaryEmbedding(RotaryEmbedding):
cos = apply_interleaved_rope(cos, self.mrope_section)
sin = apply_interleaved_rope(sin, self.mrope_section)
else:
cos = torch.cat([
m[i] for i, m in enumerate(
cos.split(self.mrope_section, dim=-1))
],
dim=-1)
sin = torch.cat([
m[i] for i, m in enumerate(
sin.split(self.mrope_section, dim=-1))
],
dim=-1)
cos = torch.cat(
[m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
dim=-1,
)
sin = torch.cat(
[m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
dim=-1,
)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
self.is_neox_style)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
self.is_neox_style)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
......@@ -318,10 +312,9 @@ class MRotaryEmbedding(RotaryEmbedding):
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None
......@@ -348,17 +341,15 @@ class MRotaryEmbedding(RotaryEmbedding):
return q.reshape(query_shape), k.reshape(key_shape)
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
self.is_neox_style)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
self.is_neox_style)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
......@@ -366,885 +357,20 @@ class MRotaryEmbedding(RotaryEmbedding):
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets)
def forward_cpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets)
@classmethod
def get_input_positions(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
second_per_grid_ts: Optional[list[float]],
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> tuple[list[list[int]], int]:
"""Get mrope input positions and delta value."""
image_grid_thw = [] if image_grid_thw is None else image_grid_thw
video_grid_thw = [] if video_grid_thw is None else video_grid_thw
second_per_grid_ts = [] if second_per_grid_ts is None else \
second_per_grid_ts
llm_positions, mrope_position_delta = \
cls.get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
return llm_positions.tolist(), mrope_position_delta
@classmethod
def get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: list[float],
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
from vllm.transformers_utils.config import thinker_uses_mrope
if thinker_uses_mrope(hf_config) and hf_config.model_type == "qwen2_5_omni":
return cls._omni_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
elif hf_config.model_type in ["glm4v", "glm4v_moe"]:
return cls._glm4v_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]:
return cls._qwen3vl_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
return cls._ernie_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
elif "KeyeVL1_5" in hf_config.model_type:
return cls._keye_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
else:
return cls._vl_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
)
@classmethod
def _glm4v_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for GLM4V."""
image_token_id = hf_config.image_token_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
llm_pos_ids_list: list = []
if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False
if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx)
mm_data_idx += 1
elif modality_type == "video":
t, h, w = (
video_frame_num,
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
for t_idx in range(llm_grid_t):
t_index = torch.tensor(t_idx).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(
1, -1, 1).expand(1, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(
1, 1, -1).expand(1, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx)
mm_data_idx += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) +
st_idx)
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
return llm_positions, mrope_position_delta
@classmethod
def _qwen3vl_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""
video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw
for _ in range(t)]
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
@classmethod
def _ernie_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for Ernie VL."""
image_token_id = hf_config.im_patch_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_conv_size = hf_config.spatial_conv_size
temporal_conv_size = hf_config.temporal_conv_size
llm_pos_ids_list: list = []
if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False
if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_conv_size, w // spatial_conv_size
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx)
mm_data_idx += 1
elif modality_type == "video":
t, h, w = (
video_grid_thw[mm_data_idx][0],
video_grid_thw[mm_data_idx][1],
video_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (t //
temporal_conv_size,
h //
spatial_conv_size,
w //
spatial_conv_size)
for t_idx in range(llm_grid_t):
t_index = torch.tensor(t_idx).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(
1, -1, 1).expand(1, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(
1, 1, -1).expand(1, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx)
mm_data_idx += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) +
st_idx)
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
return llm_positions, mrope_position_delta
@classmethod
def _keye_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
video_grid_thw = video_grid_thw[0]
"""Get mrope input positions and delta value (Keye series)."""
def split_thw(
grid_thw: Union[torch.Tensor, list[int]]) -> list[list[int]]:
"""
Split grid_thw along the t dimension.
Args:
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
Returns:
List of [1, h, w] rows, repeated t times for each original row.
"""
if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
if grid_thw.numel() == 0:
return []
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
ones = torch.ones_like(hw[:, :1]) # [N,1]
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
return out.tolist()
video_grid_thw = split_thw(video_grid_thw)
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
image_nums = len(image_grid_thw)
frame_nums = len(video_grid_thw)
llm_pos_ids_list: list = []
st = 0
remain_images, remain_frames = image_nums, frame_nums
image_index, video_index = 0, 0
for _ in range(image_nums + frame_nums):
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_frames > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_frames -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w)).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
@classmethod
def _vl_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: list[float],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
tokens_per_second = getattr(hf_config.vision_config,
"tokens_per_second", 1.0)
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
video_second_per_grid_t = 0.0
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_videos > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_second_per_grid_t = 1.0
if second_per_grid_ts:
video_second_per_grid_t = second_per_grid_ts[video_index]
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w)* video_second_per_grid_t*
tokens_per_second).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
@classmethod
def _omni_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: Optional[list[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value (Qwen2.5-Omni version).
Differences from MRotaryEmbedding:
1. Add audio support (and related `audio_feature_lengths`).
2. Add `use_audio_in_video` option to read audio from video inputs.
In this case, audio and vision position ids will be split into
chunks and interleaved.
Example:
(V_i are vision position ids, A_i are audio position ids)
|V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
|vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
"""
# TODO(fyabc): refactor and share more code with
# _vl_get_input_positions_tensor.
thinker_config = hf_config.thinker_config
if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)
audio_token_id = thinker_config.audio_token_index
image_token_id = thinker_config.image_token_index
video_token_id = thinker_config.video_token_index
audio_start_token_id = thinker_config.audio_start_token_id
audio_end_token_id = thinker_config.audio_end_token_id
vision_start_token_id = thinker_config.vision_start_token_id
vision_end_token_id = thinker_config.vision_end_token_id
seconds_per_chunk = thinker_config.seconds_per_chunk
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)
src_item = input_tokens
audio_seqlens = audio_feature_lengths
if not second_per_grid_ts:
second_per_grid_ts = [1] * video_grid_thw.shape[0]
audio_idx = 0
video_idx = 0
image_idx = 0
new_src_item: list[int] = []
llm_pos_ids_list: list[torch.Tensor] = []
idx = 0
while idx < len(src_item):
new_src_item_len = len(new_src_item)
start_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
if src_item[idx] not in [
audio_token_id, video_token_id, image_token_id
]:
if use_audio_in_video and idx > 0:
if src_item[idx] == vision_end_token_id and \
src_item[idx - 1] == audio_end_token_id:
# processing the <|audio_eos|> before <|vision_eos|>
start_idx -= 1
elif src_item[idx] == audio_start_token_id and \
src_item[idx - 1] == vision_start_token_id:
# processing the <|audio_bos|> after <|vision_eos|>
start_idx -= 1
new_src_item.append(src_item[idx])
llm_pos_ids = torch.tensor([start_idx],
dtype=torch.long).expand(3, -1)
llm_pos_ids_list.append(llm_pos_ids)
elif src_item[idx] == audio_token_id:
assert audio_seqlens is not None
audio_seqlen = audio_seqlens[audio_idx]
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1)
new_src_item.extend([audio_token_id] * place_num)
llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
llm_pos_ids_list.append(llm_pos_ids)
audio_idx += 1
elif src_item[idx] == image_token_id:
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = torch.arange(grid_t) * 1 * tokens_per_second
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
start_idx, image_idx, spatial_merge_size, t_index, grid_hs,
grid_ws)
llm_pos_ids_list.append(llm_pos_ids)
vision_seqlen = image_grid_thw[image_idx].prod() // (
spatial_merge_size**2)
new_src_item.extend([image_token_id] * vision_seqlen)
image_idx += 1
elif src_item[idx] == video_token_id and not use_audio_in_video:
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (torch.arange(grid_t) *
second_per_grid_ts[video_idx] *
tokens_per_second)
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
start_idx, video_idx, spatial_merge_size, t_index, grid_hs,
grid_ws)
llm_pos_ids_list.append(llm_pos_ids)
vision_seqlen = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2)
new_src_item.extend([video_token_id] * vision_seqlen)
video_idx += 1
else:
# read audio from video
assert audio_seqlens is not None
audio_seqlen = audio_seqlens[audio_idx]
vision_seqlen = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2)
grid_t = video_grid_thw[video_idx][0]
grid_h = video_grid_thw[video_idx][1]
grid_w = video_grid_thw[video_idx][2]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
t_index = (torch.arange(grid_t) *
second_per_grid_ts[video_idx] *
tokens_per_second)
t_index_split_chunk = cls._split_list_into_ranges(
t_index, t_ntoken_per_chunk)
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
pure_audio_len = place_num - 2
added_audio_len = 0
audio_llm_pos_ids_list: list[torch.Tensor] = []
for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = len(
t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
new_src_item.extend([video_token_id] *
vision_ntoken_per_chunk)
vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
start_idx, video_idx, spatial_merge_size, t_chunk,
grid_hs, grid_ws).split(1, dim=1)
llm_pos_ids_list.extend(vision_llm_pos_ids_list)
new_src_item.extend(
min(t_ntoken_per_chunk, pure_audio_len -
added_audio_len) * [audio_token_id])
audio_start_idx = start_idx if len(
audio_llm_pos_ids_list
) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
if min(t_ntoken_per_chunk,
pure_audio_len - added_audio_len) > 0:
audio_llm_pos_ids_list = (torch.arange(
min(t_ntoken_per_chunk, pure_audio_len -
added_audio_len)).expand(3, -1) +
audio_start_idx).split(1,
dim=1)
else:
audio_llm_pos_ids_list = []
added_audio_len += min(t_ntoken_per_chunk,
pure_audio_len - added_audio_len)
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
if added_audio_len < pure_audio_len:
new_src_item.extend(
(pure_audio_len - added_audio_len) * [audio_token_id])
audio_llm_pos_ids_list = (
torch.arange(pure_audio_len - added_audio_len).expand(
3, -1) + llm_pos_ids_list[-1].max() + 1).split(
1, dim=1)
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
audio_idx += 1
video_idx += 1
# move to the next token
idx += len(new_src_item) - new_src_item_len
llm_positions = torch.cat(llm_pos_ids_list, dim=1)
mrope_position_delta = torch.cat(llm_pos_ids_list,
dim=1).max() + 1 - len(src_item)
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
@staticmethod
def _get_llm_pos_ids_for_vision(
start_idx: int,
vision_idx: int,
spatial_merge_size: int,
t_index: list[int],
grid_hs: torch.Tensor,
grid_ws: torch.Tensor,
) -> torch.Tensor:
llm_pos_ids_list = []
llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand(
len(t_index), -1, llm_grid_w).flatten())
w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
len(t_index), llm_grid_h, -1).flatten())
t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view(
-1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten()
_llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
llm_pos_ids_list.append(_llm_pos_ids + start_idx)
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
return llm_pos_ids
@staticmethod
def _split_list_into_ranges(lst: torch.Tensor,
interval: int) -> list[list[int]]:
ranges: list[list[int]] = [[]
for _ in range((max(lst) // interval) + 1)]
for num in lst:
index = num // interval
ranges[index].append(num)
return ranges
@staticmethod
def get_next_input_positions(
......@@ -1254,68 +380,24 @@ class MRotaryEmbedding(RotaryEmbedding):
) -> list[list[int]]:
return [
list(
range(context_len + mrope_position_delta,
seq_len + mrope_position_delta)) for _ in range(3)
range(
context_len + mrope_position_delta, seq_len + mrope_position_delta
)
)
for _ in range(3)
]
@staticmethod
def get_next_input_positions_tensor(out: np.ndarray, out_offset: int,
mrope_position_delta: int,
context_len: int, num_new_tokens: int):
values = np.arange(mrope_position_delta + context_len,
mrope_position_delta + context_len + num_new_tokens,
dtype=out.dtype)
out[:, out_offset:out_offset + num_new_tokens] = values
@classmethod
def omni_get_updates_use_audio_in_video(
cls,
thinker_config: PretrainedConfig,
audio_len: int,
video_grid_thw: Union[list[int], torch.Tensor],
video_second_per_grid_t: float,
) -> list[int]:
"""Get video prompt updates when `use_audio_in_video` is True.
In this case, audio and vision update ids will be split into
chunks and interleaved (details in `_omni_get_input_positions_tensor`).
<|video_bos|><|VIDEO|><|video_eos|> =>
<|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>
"""
audio_token_id = thinker_config.audio_token_index
video_token_id = thinker_config.video_token_index
audio_start_token_id = thinker_config.audio_start_token_id
audio_end_token_id = thinker_config.audio_end_token_id
seconds_per_chunk = thinker_config.seconds_per_chunk
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)
grid_t = video_grid_thw[0]
grid_h = video_grid_thw[1]
grid_w = video_grid_thw[2]
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
t_index = (torch.arange(grid_t) * video_second_per_grid_t *
tokens_per_second)
t_index_split_chunk = cls._split_list_into_ranges(
t_index, t_ntoken_per_chunk)
updates = [audio_start_token_id]
added_audio_len = 0
for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (
spatial_merge_size**2)
updates.extend([video_token_id] * vision_ntoken_per_chunk)
audio_chunk_size = min(t_ntoken_per_chunk,
audio_len - added_audio_len)
updates.extend(audio_chunk_size * [audio_token_id])
added_audio_len += audio_chunk_size
if added_audio_len < audio_len:
updates.extend((audio_len - added_audio_len) * [audio_token_id])
updates.extend([audio_end_token_id])
return updates
def get_next_input_positions_tensor(
out: np.ndarray,
out_offset: int,
mrope_position_delta: int,
context_len: int,
num_new_tokens: int,
):
values = np.arange(
mrope_position_delta + context_len,
mrope_position_delta + context_len + num_new_tokens,
dtype=out.dtype,
)
out[:, out_offset : out_offset + num_new_tokens] = values
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, MutableSequence
from collections.abc import Iterable, Mapping, MutableSequence, Callable
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
Union, overload, runtime_checkable)
......@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import supports_kw
from .interfaces_base import is_pooling_model
from .interfaces_base import is_pooling_model, VllmModel
if TYPE_CHECKING:
from vllm.config import VllmConfig
......@@ -81,10 +81,9 @@ class SupportsMultiModal(Protocol):
"""
...
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
"""
Returns multimodal embeddings generated from multimodal kwargs
Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings.
Note:
......@@ -94,11 +93,11 @@ class SupportsMultiModal(Protocol):
"""
...
def get_language_model(self) -> torch.nn.Module:
def get_language_model(self) -> VllmModel:
"""
Returns the underlying language model used for text generation.
This is typically the `torch.nn.Module` instance responsible for
This is typically the `torch.nn.Module` instance responsible for
processing the merged multimodal embeddings and producing hidden states
Returns:
......@@ -106,19 +105,83 @@ class SupportsMultiModal(Protocol):
"""
...
@overload
def get_input_embeddings(self, input_ids: Tensor) -> Tensor: ...
@overload
def get_input_embeddings(
self,
input_ids: Tensor,
multimodal_embeddings: MultiModalEmbeddings,
*,
is_multimodal: torch.Tensor,
handle_oov_mm_token: bool = False,
) -> Tensor: ...
def _get_text_embeddings(
self,
input_ids: Tensor,
get_input_embeddings: Callable[[Tensor], Tensor],
*,
is_multimodal: Optional[Tensor],
handle_oov_mm_token: bool,
) -> Tensor:
if handle_oov_mm_token and is_multimodal is not None:
is_text = ~is_multimodal
text_embeds = get_input_embeddings(input_ids[is_text])
return torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
return get_input_embeddings(input_ids)
def get_input_embeddings(
self,
input_ids: Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[Tensor] = None,
handle_oov_mm_token: bool = False,
) -> Tensor:
"""
Returns the input embeddings merged from the text embeddings from
input_ids and the multimodal embeddings generated from multimodal
kwargs.
Apply token embeddings to `input_ids`.
If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`.
In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=False`
to avoid calling the language model's `get_input_embeddings` method
on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings.
"""
...
from .utils import _merge_multimodal_embeddings
inputs_embeds = self._get_text_embeddings(
input_ids,
self.get_language_model().get_input_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229."
)
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
@runtime_checkable
class SupportsMultiModalPruning(Protocol):
"""The interface required for models that support returning both input
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Qwen team.
# Copyright 2023 The vLLM team.
......@@ -22,6 +22,7 @@
# limitations under the License.
"""Inference-only Qwen3-Omni-Moe model (thinker part)."""
import os
import math
from collections.abc import Callable, Iterable, Mapping, Sequence
......@@ -48,7 +49,9 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
)
from transformers.models.whisper import WhisperFeatureExtractor
from vllm.attention.backends.registry import _Backend
# from vllm.attention.backends.registry import _Backend
from vllm.platforms import _Backend, current_platform
from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
......@@ -106,6 +109,7 @@ from .utils import (
_merge_multimodal_embeddings,
maybe_prefix,
)
from .vision import (
conv3d_to_linear_weight,
get_llm_pos_ids_for_vision,
......@@ -143,18 +147,28 @@ class Qwen3_VisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear(
in_channels * math.prod(kernel_size),
# self.proj = ReplicatedLinear(
# in_channels * math.prod(kernel_size),
# hidden_size,
# bias=True,
# return_bias=False,
# )
self.proj = nn.Conv3d(
in_channels,
hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True,
return_bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape
if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1':
x = x.to(memory_format=torch.channels_last_3d)
x = self.proj(x)
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
# if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1':
# x = x.to(memory_format=torch.channels_last_3d)
x = self.proj(x).view(L, self.hidden_size)
return x
......@@ -308,7 +322,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
......@@ -380,9 +393,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
head_size=head_dim, dtype=torch.get_default_dtype()
)
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
......@@ -571,8 +582,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
# if name.endswith("patch_embed.proj.weight"):
# loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
......@@ -811,7 +822,8 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
if is_update_applied:
prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video)
(
prompt_ids,
prompt_ids,
prompt,
mm_placeholders,
) = self._apply_prompt_updates(
prompt_ids,
......@@ -829,7 +841,7 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
mm_item_counts,
)
else:
prompt_ids, mm_placeholders = self._apply_prompt_updates(
prompt_ids, prompt, mm_placeholders = self._apply_prompt_updates(
prompt_ids,
mm_prompt_updates,
)
......@@ -837,8 +849,7 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
mm_placeholders,
mm_item_counts,
)
return prompt_ids, mm_placeholders
return prompt_ids, prompt, mm_placeholders
def get_updates_use_audio_in_video(
self,
......@@ -1160,18 +1171,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
)
self.quant_config = quant_config
......@@ -1375,7 +1379,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
.contiguous()
)
self._set_deepstack_input_embeds(deepstack_input_embeds)
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
......@@ -1434,7 +1437,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return loaded_weights
@classmethod
def get_mrope_input_positions(
self,
input_tokens: list[int],
......
......@@ -10,6 +10,7 @@ import torch
import torch.nn as nn
from torch.func import functional_call
from transformers import PretrainedConfig
from typing_extensions import deprecated
import vllm.envs as envs
from vllm.config import VllmConfig
......@@ -391,92 +392,79 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
return " + ".join(
_embedding_count_expression(inner) for inner in embeddings)
def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]:
ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)]
for num in lst:
index = num // interval
ranges[index].append(num)
return ranges
def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor,
is_multimodal: torch.Tensor,
multimodal_embeddings: NestedTensors,
is_multimodal: torch.Tensor,
) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in
``input_ids``.
Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
positions in `inputs_embeds` corresponding to placeholder tokens in
`input_ids`.
Note:
This updates ``inputs_embeds`` in place.
This updates `inputs_embeds` in place.
"""
flattened = _flatten_embeddings(multimodal_embeddings)
if len(multimodal_embeddings) == 0:
return inputs_embeds
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
input_dtype = inputs_embeds.dtype
try:
# This is equivalent to: inputs_embeds[is_multimodal] = flattened.
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
flattened.to(dtype=inputs_embeds.dtype))
# For debugging
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
inputs_embeds.masked_scatter_(
is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype)
)
except RuntimeError as e:
num_actual_tokens = len(mm_embeds_flat)
num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)
if flattened.shape[0] != num_expected_tokens:
if num_actual_tokens != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError(
f"Attempted to assign {expr} = {flattened.shape[0]} "
f"Attempted to assign {expr} = {num_actual_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders"
) from e
else:
raise ValueError("Error during masked scatter operation") from e
return inputs_embeds
def embed_multimodal(
input_ids: torch.Tensor,
multimodal_token_id: int,
get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
multimodal_embeds: NestedTensors,
) -> torch.Tensor:
"""
Embed token IDs and multimodal inputs and combine their embeddings.
``multimodal_token_id`` is used to determine whether a token ID should
be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.
Compared to ``merge_multimodal_embeddings`, this avoids running
``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
which causes issues when the placeholder token ID exceeds the
vocabulary size of the language model.
"""
is_multimodal = input_ids == multimodal_token_id
is_text = ~is_multimodal
text_embeds = get_text_embeds(input_ids[is_text])
merged_embeds = torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
)
merged_embeds[is_text] = text_embeds
raise ValueError("Error during masked scatter operation") from e
return _merge_multimodal_embeddings(
merged_embeds,
is_multimodal,
multimodal_embeds,
)
return inputs_embeds
@deprecated(
"`merge_multimodal_embeddings` has been replaced with "
"`SupportsMultiModal.get_input_embeddings` and will be "
"removed in v0.12."
)
def merge_multimodal_embeddings(
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: NestedTensors,
placeholder_token_id: Union[int, list[int]],
placeholder_token_id: int | list[int],
) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in
``input_ids``.
Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
positions in `inputs_embeds` corresponding to placeholder tokens in
`input_ids`.
``placeholder_token_id`` can be a list of token ids (e.g, token ids
`placeholder_token_id` can be a list of token ids (e.g, token ids
of img_start, img_break, and img_end tokens) when needed: This means
the order of these tokens in the ``input_ids`` MUST MATCH the order of
their embeddings in ``multimodal_embeddings`` since we need to
the order of these tokens in the `input_ids` MUST MATCH the order of
their embeddings in `multimodal_embeddings` since we need to
slice-merge instead of individually scattering.
For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
......@@ -491,26 +479,32 @@ def merge_multimodal_embeddings(
input_ids for a correct embedding merge.
Note:
This updates ``inputs_embeds`` in place.
This updates `inputs_embeds` in place.
"""
if isinstance(placeholder_token_id, list):
placeholder_token_id = torch.tensor(
placeholder_token_id,
pin_memory=is_pin_memory_available()).to(device=input_ids.device,
non_blocking=True)
return _merge_multimodal_embeddings(
inputs_embeds,
torch.isin(input_ids, placeholder_token_id),
multimodal_embeddings,
)
is_multimodal = isin_list(input_ids, placeholder_token_id)
else:
is_multimodal = input_ids == placeholder_token_id
return _merge_multimodal_embeddings(
inputs_embeds,
(input_ids == placeholder_token_id),
multimodal_embeddings,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
def isin_list(
elements: torch.Tensor,
test_elements_list: list[int],
) -> torch.Tensor:
test_elements = torch.tensor(
test_elements_list,
pin_memory=is_pin_memory_available(),
).to(device=elements.device, non_blocking=True)
return torch.isin(elements, test_elements)
class LayerFn(Protocol):
def __call__(self, prefix: str) -> torch.nn.Module:
......
......@@ -368,6 +368,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=torch.int32)
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int64)
# Only relevant for multimodal models
if self.supports_mm_inputs:
self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
......@@ -1612,17 +1615,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
shift_computed_tokens: int = 0,
) -> list[torch.Tensor]:
) -> tuple[list[torch.Tensor], torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
mm_embeds = list[torch.Tensor]()
is_mm_embed = self.is_mm_embed.cpu
is_mm_embed[:total_num_scheduled_tokens] = False
req_start_idx = 0
should_sync_mrope_positions = False
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = []
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
req_state = self.requests[req_id]
num_computed_tokens = \
req_state.num_computed_tokens + shift_computed_tokens
num_computed_tokens = req_state.num_computed_tokens + shift_computed_tokens
for mm_feature in req_state.mm_features:
pos_info = mm_feature.mm_position
start_pos = pos_info.offset
......@@ -1649,12 +1658,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
f"Encoder cache miss for {mm_hash}."
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
True if is_embed is None else is_embed
)
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
......@@ -1662,6 +1675,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_embeds_req.append(mm_embeds_item)
if self.is_multimodal_pruning_enabled and self.uses_mrope:
assert req_state.mrope_positions is not None
should_sync_mrope_positions = True
mm_embeds_req, new_mrope_positions, new_delta = (
self.model.recompute_mrope_positions(
......@@ -1669,19 +1683,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
multimodal_embeddings=mm_embeds_req,
mrope_positions=req_state.mrope_positions,
num_computed_tokens=req_state.num_computed_tokens,
))
assert req_state.mrope_positions is not None
)
)
req_state.mrope_positions.copy_(new_mrope_positions)
req_state.mrope_position_delta = new_delta
mm_embeds.extend(mm_embeds_req)
req_start_idx += num_scheduled_tokens
is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)
if should_sync_mrope_positions:
self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu(
scheduler_output.total_num_scheduled_tokens)
self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)
return mm_embeds
return mm_embeds, is_mm_embed
def _extract_encoder_inputs(
self,
......@@ -1975,7 +1991,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
and not self.model_config.is_encoder_decoder):
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
......@@ -1983,6 +1999,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
inputs_embeds_scheduled = self.model.get_input_embeddings(
input_ids=self.input_ids.gpu[:num_scheduled_tokens],
multimodal_embeddings=mm_embeds or None,
is_multimodal=is_mm_embed,
)
# TODO(woosuk): Avoid the copy. Optimize.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try:
from ._version import __version__, __version_tuple__
__version__ = "0.11.0"
__version_tuple__ = (0, 11, 0)
__hcu_version__ = f'0.11.0+das.opt1.alpha.6c015e7.dtk25041'
from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e:
import warnings
warnings.warn(f"Failed to read commit hash:\n{e}",
warnings.warn(f"Failed to read commit hash:\n + str(e)",
RuntimeWarning,
stacklevel=2)
__version__ = "dev"
__version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str):
"""Check whether a given version matches the previous minor version.
'''Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
......@@ -23,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
"""
'''
# Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0):
return True
# Note - this won't do the right thing when we release 1.0!
assert __version_tuple__[0] == 0
# assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int)
return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
def _prev_minor_version():
"""For the purpose of testing, return a previous minor version number."""
'''For the purpose of testing, return a previous minor version number.'''
# In dev tree, this will return "0.-1", but that will work fine"
assert isinstance(__version_tuple__[1], int)
return f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment