Commit b12c902b authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'cx/v0.11.0-dev' into v0.11.0-dev-omni

parents c16e075a f39afa4a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Optional, Union
import numpy as np import numpy as np
import torch import torch
from transformers import PretrainedConfig
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
...@@ -62,10 +55,8 @@ def _triton_mrope_forward( ...@@ -62,10 +55,8 @@ def _triton_mrope_forward(
# Updated offsets for half head_dim # Updated offsets for half head_dim
cos_offsets = tl.arange(0, pad_hd // 2) cos_offsets = tl.arange(0, pad_hd // 2)
if is_interleaved: if is_interleaved:
h_mask = (((cos_offsets % 3) == 1) & h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h)
(cos_offsets <= 3 * mrope_section_h)) w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w)
w_mask = (((cos_offsets % 3) == 2) &
(cos_offsets <= 3 * mrope_section_w))
t_mask = ~(h_mask | w_mask) t_mask = ~(h_mask | w_mask)
else: else:
t_end = mrope_section_t t_end = mrope_section_t
...@@ -89,21 +80,25 @@ def _triton_mrope_forward( ...@@ -89,21 +80,25 @@ def _triton_mrope_forward(
# program instance (i.e. for the current token) separately # program instance (i.e. for the current token) separately
# #################################################################### # ####################################################################
# left half of the head # left half of the head
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange( first_half_q_offsets = (
0, pad_hd // 2)[None, :] tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange( )
0, pad_hd // 2)[None, :] first_half_k_offsets = (
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange( tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
0, pad_hd // 2)[None, :] < rd // 2) )
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange( first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
0, pad_hd // 2)[None, :] < rd // 2) tl.arange(0, pad_hd // 2)[None, :] < rd // 2
)
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
tl.arange(0, pad_hd // 2)[None, :] < rd // 2
)
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
mask=first_q_mask, sin_row.dtype
other=0).to(sin_row.dtype) )
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
mask=first_k_mask, sin_row.dtype
other=0).to(sin_row.dtype) )
# right half of the head # right half of the head
second_half_q_offsets = first_half_q_offsets + (rd // 2) second_half_q_offsets = first_half_q_offsets + (rd // 2)
...@@ -111,12 +106,12 @@ def _triton_mrope_forward( ...@@ -111,12 +106,12 @@ def _triton_mrope_forward(
second_q_mask = first_q_mask second_q_mask = first_q_mask
second_k_mask = first_k_mask second_k_mask = first_k_mask
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
mask=second_q_mask, sin_row.dtype
other=0).to(sin_row.dtype) )
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
mask=second_k_mask, sin_row.dtype
other=0).to(sin_row.dtype) )
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
# Since cos and sin are now half-size, # Since cos and sin are now half-size,
...@@ -168,7 +163,7 @@ def triton_mrope( ...@@ -168,7 +163,7 @@ def triton_mrope(
cos = cos.contiguous() cos = cos.contiguous()
sin = sin.contiguous() sin = sin.contiguous()
_triton_mrope_forward[(n_row, )]( _triton_mrope_forward[(n_row,)](
q, q,
k, k,
cos, cos,
...@@ -189,15 +184,14 @@ def triton_mrope( ...@@ -189,15 +184,14 @@ def triton_mrope(
return q, k return q, k
def apply_interleaved_rope(x: torch.Tensor, def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
mrope_section: list[int]) -> torch.Tensor:
"""Apply interleaved MRoPE to 3D rotary embeddings. """Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
interleaved [THTHWHTHW...TT], preserving frequency continuity. interleaved [THTHWHTHW...TT], preserving frequency continuity.
""" """
x_t = x[0].clone() x_t = x[0].clone()
x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3] x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3] x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
return x_t return x_t
...@@ -212,17 +206,16 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -212,17 +206,16 @@ class MRotaryEmbedding(RotaryEmbedding):
base: float, base: float,
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
mrope_section: Optional[list[int]] = None, mrope_section: list[int] | None = None,
mrope_interleaved: bool = False, mrope_interleaved: bool = False,
# YaRN parameters. # YaRN parameters.
*, *,
scaling_factor: Optional[float] = None, scaling_factor: float | None = None,
extrapolation_factor: float = 1, extrapolation_factor: float = 1,
attn_factor: float = 1, attn_factor: float = 1,
beta_fast: int = 32, beta_fast: int = 32,
beta_slow: int = 1, beta_slow: int = 1,
) -> None: ) -> None:
self.scaling_factor = scaling_factor self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor self.attn_factor = attn_factor
...@@ -230,8 +223,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -230,8 +223,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self.beta_slow = beta_slow self.beta_slow = beta_slow
if self.scaling_factor is not None: if self.scaling_factor is not None:
# Get n-d magnitude scaling corrected for interpolation # Get n-d magnitude scaling corrected for interpolation
self.mscale = float( self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
yarn_get_mscale(self.scaling_factor) * attn_factor)
else: else:
self.mscale = 1.0 self.mscale = 1.0
...@@ -239,8 +231,14 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -239,8 +231,14 @@ class MRotaryEmbedding(RotaryEmbedding):
# the input video. We enlarge max_position_embeddings to 4 times to get # the input video. We enlarge max_position_embeddings to 4 times to get
# a larger the cos and sin cache. # a larger the cos and sin cache.
self.cache_max_position_num = max_position_embeddings * 4 self.cache_max_position_num = max_position_embeddings * 4
super().__init__(head_size, rotary_dim, self.cache_max_position_num, super().__init__(
base, is_neox_style, dtype) head_size,
rotary_dim,
self.cache_max_position_num,
base,
is_neox_style,
dtype,
)
self.mrope_section = mrope_section self.mrope_section = mrope_section
self.mrope_interleaved = mrope_interleaved self.mrope_interleaved = mrope_interleaved
...@@ -261,9 +259,9 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -261,9 +259,9 @@ class MRotaryEmbedding(RotaryEmbedding):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: Optional[torch.Tensor] = None, key: torch.Tensor | None = None,
offsets: Optional[torch.Tensor] = None, offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
"""PyTorch-native implementation equivalent to forward(). """PyTorch-native implementation equivalent to forward().
Args: Args:
...@@ -286,31 +284,27 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -286,31 +284,27 @@ class MRotaryEmbedding(RotaryEmbedding):
cos = apply_interleaved_rope(cos, self.mrope_section) cos = apply_interleaved_rope(cos, self.mrope_section)
sin = apply_interleaved_rope(sin, self.mrope_section) sin = apply_interleaved_rope(sin, self.mrope_section)
else: else:
cos = torch.cat([ cos = torch.cat(
m[i] for i, m in enumerate( [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
cos.split(self.mrope_section, dim=-1)) dim=-1,
], )
dim=-1) sin = torch.cat(
sin = torch.cat([ [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
m[i] for i, m in enumerate( dim=-1,
sin.split(self.mrope_section, dim=-1)) )
],
dim=-1)
query_shape = query.shape query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size) query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim] query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim:] query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size) key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim] key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key return query, key
...@@ -318,10 +312,9 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -318,10 +312,9 @@ class MRotaryEmbedding(RotaryEmbedding):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: Optional[torch.Tensor] = None, key: torch.Tensor | None = None,
offsets: Optional[torch.Tensor] = None, offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
assert positions.ndim == 1 or positions.ndim == 2 assert positions.ndim == 1 or positions.ndim == 2
assert key is not None assert key is not None
...@@ -348,17 +341,15 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -348,17 +341,15 @@ class MRotaryEmbedding(RotaryEmbedding):
return q.reshape(query_shape), k.reshape(key_shape) return q.reshape(query_shape), k.reshape(key_shape)
query = query.view(num_tokens, -1, self.head_size) query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim] query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim:] query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key = key.view(num_tokens, -1, self.head_size) key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim] key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key return query, key
...@@ -366,885 +357,20 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -366,885 +357,20 @@ class MRotaryEmbedding(RotaryEmbedding):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: Optional[torch.Tensor] = None, key: torch.Tensor | None = None,
offsets: Optional[torch.Tensor] = None, offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets) return self.forward_native(positions, query, key, offsets)
def forward_cpu( def forward_cpu(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: Optional[torch.Tensor] = None, key: torch.Tensor | None = None,
offsets: Optional[torch.Tensor] = None, offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets) return self.forward_native(positions, query, key, offsets)
@classmethod
def get_input_positions(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
second_per_grid_ts: Optional[list[float]],
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> tuple[list[list[int]], int]:
"""Get mrope input positions and delta value."""
image_grid_thw = [] if image_grid_thw is None else image_grid_thw
video_grid_thw = [] if video_grid_thw is None else video_grid_thw
second_per_grid_ts = [] if second_per_grid_ts is None else \
second_per_grid_ts
llm_positions, mrope_position_delta = \
cls.get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
return llm_positions.tolist(), mrope_position_delta
@classmethod
def get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: list[float],
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
from vllm.transformers_utils.config import thinker_uses_mrope
if thinker_uses_mrope(hf_config) and hf_config.model_type == "qwen2_5_omni":
return cls._omni_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
elif hf_config.model_type in ["glm4v", "glm4v_moe"]:
return cls._glm4v_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]:
return cls._qwen3vl_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
return cls._ernie_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
elif "KeyeVL1_5" in hf_config.model_type:
return cls._keye_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
else:
return cls._vl_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
)
@classmethod
def _glm4v_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for GLM4V."""
image_token_id = hf_config.image_token_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
llm_pos_ids_list: list = []
if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False
if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx)
mm_data_idx += 1
elif modality_type == "video":
t, h, w = (
video_frame_num,
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
for t_idx in range(llm_grid_t):
t_index = torch.tensor(t_idx).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(
1, -1, 1).expand(1, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(
1, 1, -1).expand(1, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx)
mm_data_idx += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) +
st_idx)
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
return llm_positions, mrope_position_delta
@classmethod
def _qwen3vl_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""
video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw
for _ in range(t)]
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
@classmethod
def _ernie_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for Ernie VL."""
image_token_id = hf_config.im_patch_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_conv_size = hf_config.spatial_conv_size
temporal_conv_size = hf_config.temporal_conv_size
llm_pos_ids_list: list = []
if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False
if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_conv_size, w // spatial_conv_size
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx)
mm_data_idx += 1
elif modality_type == "video":
t, h, w = (
video_grid_thw[mm_data_idx][0],
video_grid_thw[mm_data_idx][1],
video_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (t //
temporal_conv_size,
h //
spatial_conv_size,
w //
spatial_conv_size)
for t_idx in range(llm_grid_t):
t_index = torch.tensor(t_idx).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(
1, -1, 1).expand(1, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(
1, 1, -1).expand(1, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx)
mm_data_idx += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) +
st_idx)
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
return llm_positions, mrope_position_delta
@classmethod
def _keye_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
video_grid_thw = video_grid_thw[0]
"""Get mrope input positions and delta value (Keye series)."""
def split_thw(
grid_thw: Union[torch.Tensor, list[int]]) -> list[list[int]]:
"""
Split grid_thw along the t dimension.
Args:
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
Returns:
List of [1, h, w] rows, repeated t times for each original row.
"""
if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
if grid_thw.numel() == 0:
return []
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
ones = torch.ones_like(hw[:, :1]) # [N,1]
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
return out.tolist()
video_grid_thw = split_thw(video_grid_thw)
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
image_nums = len(image_grid_thw)
frame_nums = len(video_grid_thw)
llm_pos_ids_list: list = []
st = 0
remain_images, remain_frames = image_nums, frame_nums
image_index, video_index = 0, 0
for _ in range(image_nums + frame_nums):
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_frames > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_frames -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w)).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
@classmethod
def _vl_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: list[float],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
tokens_per_second = getattr(hf_config.vision_config,
"tokens_per_second", 1.0)
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
video_second_per_grid_t = 0.0
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_videos > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_second_per_grid_t = 1.0
if second_per_grid_ts:
video_second_per_grid_t = second_per_grid_ts[video_index]
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w)* video_second_per_grid_t*
tokens_per_second).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
@classmethod
def _omni_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: Optional[list[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value (Qwen2.5-Omni version).
Differences from MRotaryEmbedding:
1. Add audio support (and related `audio_feature_lengths`).
2. Add `use_audio_in_video` option to read audio from video inputs.
In this case, audio and vision position ids will be split into
chunks and interleaved.
Example:
(V_i are vision position ids, A_i are audio position ids)
|V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
|vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
"""
# TODO(fyabc): refactor and share more code with
# _vl_get_input_positions_tensor.
thinker_config = hf_config.thinker_config
if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)
audio_token_id = thinker_config.audio_token_index
image_token_id = thinker_config.image_token_index
video_token_id = thinker_config.video_token_index
audio_start_token_id = thinker_config.audio_start_token_id
audio_end_token_id = thinker_config.audio_end_token_id
vision_start_token_id = thinker_config.vision_start_token_id
vision_end_token_id = thinker_config.vision_end_token_id
seconds_per_chunk = thinker_config.seconds_per_chunk
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)
src_item = input_tokens
audio_seqlens = audio_feature_lengths
if not second_per_grid_ts:
second_per_grid_ts = [1] * video_grid_thw.shape[0]
audio_idx = 0
video_idx = 0
image_idx = 0
new_src_item: list[int] = []
llm_pos_ids_list: list[torch.Tensor] = []
idx = 0
while idx < len(src_item):
new_src_item_len = len(new_src_item)
start_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
if src_item[idx] not in [
audio_token_id, video_token_id, image_token_id
]:
if use_audio_in_video and idx > 0:
if src_item[idx] == vision_end_token_id and \
src_item[idx - 1] == audio_end_token_id:
# processing the <|audio_eos|> before <|vision_eos|>
start_idx -= 1
elif src_item[idx] == audio_start_token_id and \
src_item[idx - 1] == vision_start_token_id:
# processing the <|audio_bos|> after <|vision_eos|>
start_idx -= 1
new_src_item.append(src_item[idx])
llm_pos_ids = torch.tensor([start_idx],
dtype=torch.long).expand(3, -1)
llm_pos_ids_list.append(llm_pos_ids)
elif src_item[idx] == audio_token_id:
assert audio_seqlens is not None
audio_seqlen = audio_seqlens[audio_idx]
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1)
new_src_item.extend([audio_token_id] * place_num)
llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
llm_pos_ids_list.append(llm_pos_ids)
audio_idx += 1
elif src_item[idx] == image_token_id:
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = torch.arange(grid_t) * 1 * tokens_per_second
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
start_idx, image_idx, spatial_merge_size, t_index, grid_hs,
grid_ws)
llm_pos_ids_list.append(llm_pos_ids)
vision_seqlen = image_grid_thw[image_idx].prod() // (
spatial_merge_size**2)
new_src_item.extend([image_token_id] * vision_seqlen)
image_idx += 1
elif src_item[idx] == video_token_id and not use_audio_in_video:
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (torch.arange(grid_t) *
second_per_grid_ts[video_idx] *
tokens_per_second)
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
start_idx, video_idx, spatial_merge_size, t_index, grid_hs,
grid_ws)
llm_pos_ids_list.append(llm_pos_ids)
vision_seqlen = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2)
new_src_item.extend([video_token_id] * vision_seqlen)
video_idx += 1
else:
# read audio from video
assert audio_seqlens is not None
audio_seqlen = audio_seqlens[audio_idx]
vision_seqlen = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2)
grid_t = video_grid_thw[video_idx][0]
grid_h = video_grid_thw[video_idx][1]
grid_w = video_grid_thw[video_idx][2]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
t_index = (torch.arange(grid_t) *
second_per_grid_ts[video_idx] *
tokens_per_second)
t_index_split_chunk = cls._split_list_into_ranges(
t_index, t_ntoken_per_chunk)
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
pure_audio_len = place_num - 2
added_audio_len = 0
audio_llm_pos_ids_list: list[torch.Tensor] = []
for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = len(
t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
new_src_item.extend([video_token_id] *
vision_ntoken_per_chunk)
vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
start_idx, video_idx, spatial_merge_size, t_chunk,
grid_hs, grid_ws).split(1, dim=1)
llm_pos_ids_list.extend(vision_llm_pos_ids_list)
new_src_item.extend(
min(t_ntoken_per_chunk, pure_audio_len -
added_audio_len) * [audio_token_id])
audio_start_idx = start_idx if len(
audio_llm_pos_ids_list
) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
if min(t_ntoken_per_chunk,
pure_audio_len - added_audio_len) > 0:
audio_llm_pos_ids_list = (torch.arange(
min(t_ntoken_per_chunk, pure_audio_len -
added_audio_len)).expand(3, -1) +
audio_start_idx).split(1,
dim=1)
else:
audio_llm_pos_ids_list = []
added_audio_len += min(t_ntoken_per_chunk,
pure_audio_len - added_audio_len)
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
if added_audio_len < pure_audio_len:
new_src_item.extend(
(pure_audio_len - added_audio_len) * [audio_token_id])
audio_llm_pos_ids_list = (
torch.arange(pure_audio_len - added_audio_len).expand(
3, -1) + llm_pos_ids_list[-1].max() + 1).split(
1, dim=1)
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
audio_idx += 1
video_idx += 1
# move to the next token
idx += len(new_src_item) - new_src_item_len
llm_positions = torch.cat(llm_pos_ids_list, dim=1)
mrope_position_delta = torch.cat(llm_pos_ids_list,
dim=1).max() + 1 - len(src_item)
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
@staticmethod
def _get_llm_pos_ids_for_vision(
start_idx: int,
vision_idx: int,
spatial_merge_size: int,
t_index: list[int],
grid_hs: torch.Tensor,
grid_ws: torch.Tensor,
) -> torch.Tensor:
llm_pos_ids_list = []
llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand(
len(t_index), -1, llm_grid_w).flatten())
w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
len(t_index), llm_grid_h, -1).flatten())
t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view(
-1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten()
_llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
llm_pos_ids_list.append(_llm_pos_ids + start_idx)
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
return llm_pos_ids
@staticmethod
def _split_list_into_ranges(lst: torch.Tensor,
interval: int) -> list[list[int]]:
ranges: list[list[int]] = [[]
for _ in range((max(lst) // interval) + 1)]
for num in lst:
index = num // interval
ranges[index].append(num)
return ranges
@staticmethod @staticmethod
def get_next_input_positions( def get_next_input_positions(
...@@ -1254,68 +380,24 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1254,68 +380,24 @@ class MRotaryEmbedding(RotaryEmbedding):
) -> list[list[int]]: ) -> list[list[int]]:
return [ return [
list( list(
range(context_len + mrope_position_delta, range(
seq_len + mrope_position_delta)) for _ in range(3) context_len + mrope_position_delta, seq_len + mrope_position_delta
)
)
for _ in range(3)
] ]
@staticmethod @staticmethod
def get_next_input_positions_tensor(out: np.ndarray, out_offset: int, def get_next_input_positions_tensor(
out: np.ndarray,
out_offset: int,
mrope_position_delta: int, mrope_position_delta: int,
context_len: int, num_new_tokens: int): context_len: int,
num_new_tokens: int,
values = np.arange(mrope_position_delta + context_len, ):
values = np.arange(
mrope_position_delta + context_len,
mrope_position_delta + context_len + num_new_tokens, mrope_position_delta + context_len + num_new_tokens,
dtype=out.dtype) dtype=out.dtype,
out[:, out_offset:out_offset + num_new_tokens] = values )
out[:, out_offset : out_offset + num_new_tokens] = values
@classmethod
def omni_get_updates_use_audio_in_video(
cls,
thinker_config: PretrainedConfig,
audio_len: int,
video_grid_thw: Union[list[int], torch.Tensor],
video_second_per_grid_t: float,
) -> list[int]:
"""Get video prompt updates when `use_audio_in_video` is True.
In this case, audio and vision update ids will be split into
chunks and interleaved (details in `_omni_get_input_positions_tensor`).
<|video_bos|><|VIDEO|><|video_eos|> =>
<|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>
"""
audio_token_id = thinker_config.audio_token_index
video_token_id = thinker_config.video_token_index
audio_start_token_id = thinker_config.audio_start_token_id
audio_end_token_id = thinker_config.audio_end_token_id
seconds_per_chunk = thinker_config.seconds_per_chunk
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)
grid_t = video_grid_thw[0]
grid_h = video_grid_thw[1]
grid_w = video_grid_thw[2]
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
t_index = (torch.arange(grid_t) * video_second_per_grid_t *
tokens_per_second)
t_index_split_chunk = cls._split_list_into_ranges(
t_index, t_ntoken_per_chunk)
updates = [audio_start_token_id]
added_audio_len = 0
for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (
spatial_merge_size**2)
updates.extend([video_token_id] * vision_ntoken_per_chunk)
audio_chunk_size = min(t_ntoken_per_chunk,
audio_len - added_audio_len)
updates.extend(audio_chunk_size * [audio_token_id])
added_audio_len += audio_chunk_size
if added_audio_len < audio_len:
updates.extend((audio_len - added_audio_len) * [audio_token_id])
updates.extend([audio_end_token_id])
return updates
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, MutableSequence from collections.abc import Iterable, Mapping, MutableSequence, Callable
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
Union, overload, runtime_checkable) Union, overload, runtime_checkable)
...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.utils import supports_kw from vllm.utils import supports_kw
from .interfaces_base import is_pooling_model from .interfaces_base import is_pooling_model, VllmModel
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -81,8 +81,7 @@ class SupportsMultiModal(Protocol): ...@@ -81,8 +81,7 @@ class SupportsMultiModal(Protocol):
""" """
... ...
def get_multimodal_embeddings(self, def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
**kwargs: object) -> MultiModalEmbeddings:
""" """
Returns multimodal embeddings generated from multimodal kwargs Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings. to be merged with text embeddings.
...@@ -94,7 +93,7 @@ class SupportsMultiModal(Protocol): ...@@ -94,7 +93,7 @@ class SupportsMultiModal(Protocol):
""" """
... ...
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> VllmModel:
""" """
Returns the underlying language model used for text generation. Returns the underlying language model used for text generation.
...@@ -106,19 +105,83 @@ class SupportsMultiModal(Protocol): ...@@ -106,19 +105,83 @@ class SupportsMultiModal(Protocol):
""" """
... ...
@overload
def get_input_embeddings(self, input_ids: Tensor) -> Tensor: ...
@overload
def get_input_embeddings(
self,
input_ids: Tensor,
multimodal_embeddings: MultiModalEmbeddings,
*,
is_multimodal: torch.Tensor,
handle_oov_mm_token: bool = False,
) -> Tensor: ...
def _get_text_embeddings(
self,
input_ids: Tensor,
get_input_embeddings: Callable[[Tensor], Tensor],
*,
is_multimodal: Optional[Tensor],
handle_oov_mm_token: bool,
) -> Tensor:
if handle_oov_mm_token and is_multimodal is not None:
is_text = ~is_multimodal
text_embeds = get_input_embeddings(input_ids[is_text])
return torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
return get_input_embeddings(input_ids)
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: Tensor, input_ids: Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[Tensor] = None,
handle_oov_mm_token: bool = False,
) -> Tensor: ) -> Tensor:
""" """
Returns the input embeddings merged from the text embeddings from Apply token embeddings to `input_ids`.
input_ids and the multimodal embeddings generated from multimodal
kwargs. If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`.
In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=False`
to avoid calling the language model's `get_input_embeddings` method
on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings.
""" """
... from .utils import _merge_multimodal_embeddings
inputs_embeds = self._get_text_embeddings(
input_ids,
self.get_language_model().get_input_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229."
)
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
@runtime_checkable @runtime_checkable
class SupportsMultiModalPruning(Protocol): class SupportsMultiModalPruning(Protocol):
"""The interface required for models that support returning both input """The interface required for models that support returning both input
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Qwen team. # Copyright 2025 The Qwen team.
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen3-Omni-Moe model (thinker part).""" """Inference-only Qwen3-Omni-Moe model (thinker part)."""
import os import os
import math import math
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
...@@ -48,7 +49,9 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import ( ...@@ -48,7 +49,9 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
) )
from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper import WhisperFeatureExtractor
from vllm.attention.backends.registry import _Backend
# from vllm.attention.backends.registry import _Backend
from vllm.platforms import _Backend, current_platform
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -106,6 +109,7 @@ from .utils import ( ...@@ -106,6 +109,7 @@ from .utils import (
_merge_multimodal_embeddings, _merge_multimodal_embeddings,
maybe_prefix, maybe_prefix,
) )
from .vision import ( from .vision import (
conv3d_to_linear_weight, conv3d_to_linear_weight,
get_llm_pos_ids_for_vision, get_llm_pos_ids_for_vision,
...@@ -143,18 +147,28 @@ class Qwen3_VisionPatchEmbed(nn.Module): ...@@ -143,18 +147,28 @@ class Qwen3_VisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size) kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear(
in_channels * math.prod(kernel_size), # self.proj = ReplicatedLinear(
# in_channels * math.prod(kernel_size),
# hidden_size,
# bias=True,
# return_bias=False,
# )
self.proj = nn.Conv3d(
in_channels,
hidden_size, hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True, bias=True,
return_bias=False,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape L, C = x.shape
if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1': x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = x.to(memory_format=torch.channels_last_3d) # if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1':
x = self.proj(x) # x = x.to(memory_format=torch.channels_last_3d)
x = self.proj(x).view(L, self.hidden_size)
return x return x
...@@ -308,7 +322,6 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -308,7 +322,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = vision_config.hidden_size self.hidden_size = vision_config.hidden_size
...@@ -380,9 +393,7 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -380,9 +393,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
) )
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim, dtype=torch.get_default_dtype()
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
) )
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype() torch.get_default_dtype()
...@@ -571,8 +582,8 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -571,8 +582,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"): # if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight) # loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
...@@ -812,6 +823,7 @@ class Qwen3OmniMoeThinkerMultiModalProcessor( ...@@ -812,6 +823,7 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video) prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video)
( (
prompt_ids, prompt_ids,
prompt,
mm_placeholders, mm_placeholders,
) = self._apply_prompt_updates( ) = self._apply_prompt_updates(
prompt_ids, prompt_ids,
...@@ -829,7 +841,7 @@ class Qwen3OmniMoeThinkerMultiModalProcessor( ...@@ -829,7 +841,7 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
mm_item_counts, mm_item_counts,
) )
else: else:
prompt_ids, mm_placeholders = self._apply_prompt_updates( prompt_ids, prompt, mm_placeholders = self._apply_prompt_updates(
prompt_ids, prompt_ids,
mm_prompt_updates, mm_prompt_updates,
) )
...@@ -837,8 +849,7 @@ class Qwen3OmniMoeThinkerMultiModalProcessor( ...@@ -837,8 +849,7 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
mm_placeholders, mm_placeholders,
mm_item_counts, mm_item_counts,
) )
return prompt_ids, prompt, mm_placeholders
return prompt_ids, mm_placeholders
def get_updates_use_audio_in_video( def get_updates_use_audio_in_video(
self, self,
...@@ -1160,18 +1171,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1160,18 +1171,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
) )
self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config) self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen3Omni_VisionTransformer( self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config, vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
) )
self.quant_config = quant_config self.quant_config = quant_config
...@@ -1375,7 +1379,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1375,7 +1379,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
.contiguous() .contiguous()
) )
self._set_deepstack_input_embeds(deepstack_input_embeds) self._set_deepstack_input_embeds(deepstack_input_embeds)
inputs_embeds = _merge_multimodal_embeddings( inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
...@@ -1435,6 +1438,8 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1435,6 +1438,8 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
return loaded_weights return loaded_weights
@classmethod
def get_mrope_input_positions( def get_mrope_input_positions(
self, self,
input_tokens: list[int], input_tokens: list[int],
......
...@@ -10,6 +10,7 @@ import torch ...@@ -10,6 +10,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.func import functional_call from torch.func import functional_call
from transformers import PretrainedConfig from transformers import PretrainedConfig
from typing_extensions import deprecated
import vllm.envs as envs import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -391,92 +392,79 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: ...@@ -391,92 +392,79 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
return " + ".join( return " + ".join(
_embedding_count_expression(inner) for inner in embeddings) _embedding_count_expression(inner) for inner in embeddings)
def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]:
ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)]
for num in lst:
index = num // interval
ranges[index].append(num)
return ranges
def _merge_multimodal_embeddings( def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
is_multimodal: torch.Tensor,
multimodal_embeddings: NestedTensors, multimodal_embeddings: NestedTensors,
is_multimodal: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in positions in `inputs_embeds` corresponding to placeholder tokens in
``input_ids``. `input_ids`.
Note: Note:
This updates ``inputs_embeds`` in place. This updates `inputs_embeds` in place.
""" """
flattened = _flatten_embeddings(multimodal_embeddings) if len(multimodal_embeddings) == 0:
return inputs_embeds
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
input_dtype = inputs_embeds.dtype
try: try:
# This is equivalent to: inputs_embeds[is_multimodal] = flattened. # For debugging
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
flattened.to(dtype=inputs_embeds.dtype))
# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
inputs_embeds.masked_scatter_(
is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype)
)
except RuntimeError as e: except RuntimeError as e:
num_actual_tokens = len(mm_embeds_flat)
num_expected_tokens = is_multimodal.sum().item() num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)
if flattened.shape[0] != num_expected_tokens: if num_actual_tokens != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings) expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError( raise ValueError(
f"Attempted to assign {expr} = {flattened.shape[0]} " f"Attempted to assign {expr} = {num_actual_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders" f"multimodal tokens to {num_expected_tokens} placeholders"
) from e ) from e
else:
raise ValueError("Error during masked scatter operation") from e raise ValueError("Error during masked scatter operation") from e
return inputs_embeds return inputs_embeds
def embed_multimodal( @deprecated(
input_ids: torch.Tensor, "`merge_multimodal_embeddings` has been replaced with "
multimodal_token_id: int, "`SupportsMultiModal.get_input_embeddings` and will be "
get_text_embeds: Callable[[torch.Tensor], torch.Tensor], "removed in v0.12."
multimodal_embeds: NestedTensors, )
) -> torch.Tensor:
"""
Embed token IDs and multimodal inputs and combine their embeddings.
``multimodal_token_id`` is used to determine whether a token ID should
be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.
Compared to ``merge_multimodal_embeddings`, this avoids running
``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
which causes issues when the placeholder token ID exceeds the
vocabulary size of the language model.
"""
is_multimodal = input_ids == multimodal_token_id
is_text = ~is_multimodal
text_embeds = get_text_embeds(input_ids[is_text])
merged_embeds = torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
)
merged_embeds[is_text] = text_embeds
return _merge_multimodal_embeddings(
merged_embeds,
is_multimodal,
multimodal_embeds,
)
def merge_multimodal_embeddings( def merge_multimodal_embeddings(
input_ids: torch.Tensor, input_ids: torch.Tensor,
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
multimodal_embeddings: NestedTensors, multimodal_embeddings: NestedTensors,
placeholder_token_id: Union[int, list[int]], placeholder_token_id: int | list[int],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in positions in `inputs_embeds` corresponding to placeholder tokens in
``input_ids``. `input_ids`.
``placeholder_token_id`` can be a list of token ids (e.g, token ids `placeholder_token_id` can be a list of token ids (e.g, token ids
of img_start, img_break, and img_end tokens) when needed: This means of img_start, img_break, and img_end tokens) when needed: This means
the order of these tokens in the ``input_ids`` MUST MATCH the order of the order of these tokens in the `input_ids` MUST MATCH the order of
their embeddings in ``multimodal_embeddings`` since we need to their embeddings in `multimodal_embeddings` since we need to
slice-merge instead of individually scattering. slice-merge instead of individually scattering.
For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
...@@ -491,26 +479,32 @@ def merge_multimodal_embeddings( ...@@ -491,26 +479,32 @@ def merge_multimodal_embeddings(
input_ids for a correct embedding merge. input_ids for a correct embedding merge.
Note: Note:
This updates ``inputs_embeds`` in place. This updates `inputs_embeds` in place.
""" """
if isinstance(placeholder_token_id, list): if isinstance(placeholder_token_id, list):
placeholder_token_id = torch.tensor( is_multimodal = isin_list(input_ids, placeholder_token_id)
placeholder_token_id, else:
pin_memory=is_pin_memory_available()).to(device=input_ids.device, is_multimodal = input_ids == placeholder_token_id
non_blocking=True)
return _merge_multimodal_embeddings(
inputs_embeds,
torch.isin(input_ids, placeholder_token_id),
multimodal_embeddings,
)
return _merge_multimodal_embeddings( return _merge_multimodal_embeddings(
inputs_embeds, inputs_embeds,
(input_ids == placeholder_token_id), multimodal_embeddings=multimodal_embeddings,
multimodal_embeddings, is_multimodal=is_multimodal,
) )
def isin_list(
elements: torch.Tensor,
test_elements_list: list[int],
) -> torch.Tensor:
test_elements = torch.tensor(
test_elements_list,
pin_memory=is_pin_memory_available(),
).to(device=elements.device, non_blocking=True)
return torch.isin(elements, test_elements)
class LayerFn(Protocol): class LayerFn(Protocol):
def __call__(self, prefix: str) -> torch.nn.Module: def __call__(self, prefix: str) -> torch.nn.Module:
......
...@@ -368,6 +368,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -368,6 +368,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=torch.int32) dtype=torch.int32)
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int64) dtype=torch.int64)
# Only relevant for multimodal models
if self.supports_mm_inputs:
self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope: if self.uses_mrope:
...@@ -1612,17 +1615,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1612,17 +1615,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
shift_computed_tokens: int = 0, shift_computed_tokens: int = 0,
) -> list[torch.Tensor]: ) -> tuple[list[torch.Tensor], torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
mm_embeds = list[torch.Tensor]()
is_mm_embed = self.is_mm_embed.cpu
is_mm_embed[:total_num_scheduled_tokens] = False
req_start_idx = 0
should_sync_mrope_positions = False should_sync_mrope_positions = False
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids: for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = [] mm_embeds_req: list[torch.Tensor] = []
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
req_id]
req_state = self.requests[req_id] req_state = self.requests[req_id]
num_computed_tokens = \ num_computed_tokens = req_state.num_computed_tokens + shift_computed_tokens
req_state.num_computed_tokens + shift_computed_tokens
for mm_feature in req_state.mm_features: for mm_feature in req_state.mm_features:
pos_info = mm_feature.mm_position pos_info = mm_feature.mm_position
start_pos = pos_info.offset start_pos = pos_info.offset
...@@ -1649,12 +1658,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1649,12 +1658,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_hash = mm_feature.identifier mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None) encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\ assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
f"Encoder cache miss for {mm_hash}."
if (is_embed := pos_info.is_embed) is not None: if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx] is_embed = is_embed[start_idx:end_idx]
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
True if is_embed is None else is_embed
)
mm_embeds_item = gather_mm_placeholders( mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx], encoder_output[start_idx:end_idx],
is_embed=is_embed, is_embed=is_embed,
...@@ -1662,6 +1675,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1662,6 +1675,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_embeds_req.append(mm_embeds_item) mm_embeds_req.append(mm_embeds_item)
if self.is_multimodal_pruning_enabled and self.uses_mrope: if self.is_multimodal_pruning_enabled and self.uses_mrope:
assert req_state.mrope_positions is not None
should_sync_mrope_positions = True should_sync_mrope_positions = True
mm_embeds_req, new_mrope_positions, new_delta = ( mm_embeds_req, new_mrope_positions, new_delta = (
self.model.recompute_mrope_positions( self.model.recompute_mrope_positions(
...@@ -1669,19 +1683,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1669,19 +1683,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
multimodal_embeddings=mm_embeds_req, multimodal_embeddings=mm_embeds_req,
mrope_positions=req_state.mrope_positions, mrope_positions=req_state.mrope_positions,
num_computed_tokens=req_state.num_computed_tokens, num_computed_tokens=req_state.num_computed_tokens,
)) )
assert req_state.mrope_positions is not None )
req_state.mrope_positions.copy_(new_mrope_positions) req_state.mrope_positions.copy_(new_mrope_positions)
req_state.mrope_position_delta = new_delta req_state.mrope_position_delta = new_delta
mm_embeds.extend(mm_embeds_req) mm_embeds.extend(mm_embeds_req)
req_start_idx += num_scheduled_tokens
is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)
if should_sync_mrope_positions: if should_sync_mrope_positions:
self._calc_mrope_positions(scheduler_output) self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu( self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)
scheduler_output.total_num_scheduled_tokens)
return mm_embeds return mm_embeds, is_mm_embed
def _extract_encoder_inputs( def _extract_encoder_inputs(
self, self,
...@@ -1975,7 +1991,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1975,7 +1991,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
and not self.model_config.is_encoder_decoder): and not self.model_config.is_encoder_decoder):
# Run the multimodal encoder if any. # Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output) self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output) mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)
# NOTE(woosuk): To unify token ids and soft tokens (vision # NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids) # embeddings), we always use embeddings (rather than token ids)
...@@ -1983,6 +1999,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1983,6 +1999,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
inputs_embeds_scheduled = self.model.get_input_embeddings( inputs_embeds_scheduled = self.model.get_input_embeddings(
input_ids=self.input_ids.gpu[:num_scheduled_tokens], input_ids=self.input_ids.gpu[:num_scheduled_tokens],
multimodal_embeddings=mm_embeds or None, multimodal_embeddings=mm_embeds or None,
is_multimodal=is_mm_embed,
) )
# TODO(woosuk): Avoid the copy. Optimize. # TODO(woosuk): Avoid the copy. Optimize.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try: try:
from ._version import __version__, __version_tuple__ __version__ = "0.11.0"
__version_tuple__ = (0, 11, 0)
__hcu_version__ = f'0.11.0+das.opt1.alpha.6c015e7.dtk25041'
from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e: except Exception as e:
import warnings import warnings
warnings.warn(f"Failed to read commit hash:\n{e}", warnings.warn(f"Failed to read commit hash:\n + str(e)",
RuntimeWarning, RuntimeWarning,
stacklevel=2) stacklevel=2)
__version__ = "dev" __version__ = "dev"
__version_tuple__ = (0, 0, __version__) __version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str): def _prev_minor_version_was(version_str):
"""Check whether a given version matches the previous minor version. '''Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version. Return True if version_str matches the previous minor version.
...@@ -23,19 +24,19 @@ def _prev_minor_version_was(version_str): ...@@ -23,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'. supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version. Used for --show-hidden-metrics-for-version.
""" '''
# Match anything if this is a dev tree # Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0): if __version_tuple__[0:2] == (0, 0):
return True return True
# Note - this won't do the right thing when we release 1.0! # Note - this won't do the right thing when we release 1.0!
assert __version_tuple__[0] == 0 # assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int) assert isinstance(__version_tuple__[1], int)
return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}" return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
def _prev_minor_version(): def _prev_minor_version():
"""For the purpose of testing, return a previous minor version number.""" '''For the purpose of testing, return a previous minor version number.'''
# In dev tree, this will return "0.-1", but that will work fine" # In dev tree, this will return "0.-1", but that will work fine"
assert isinstance(__version_tuple__[1], int) assert isinstance(__version_tuple__[1], int)
return f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}" return f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment