"vllm/vscode:/vscode.git/clone" did not exist on "93dc5a287086299a124e9f1f6fac75458ae0acbd"
Commit d76fc11e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.15.0rc1' into v0.15.0rc1-dev

parents 38166ec4 58996f35
...@@ -439,7 +439,7 @@ class Qwen2Model(nn.Module): ...@@ -439,7 +439,7 @@ class Qwen2Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -659,7 +659,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): ...@@ -659,7 +659,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -1298,7 +1298,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -1298,7 +1298,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -1509,7 +1509,7 @@ class Qwen2_5_VLForConditionalGeneration( ...@@ -1509,7 +1509,7 @@ class Qwen2_5_VLForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -451,7 +451,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports ...@@ -451,7 +451,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -408,7 +408,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -408,7 +408,7 @@ class Qwen2MoeModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -633,7 +633,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -633,7 +633,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -79,7 +79,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): ...@@ -79,7 +79,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -1432,7 +1432,7 @@ class Qwen2VLForConditionalGeneration( ...@@ -1432,7 +1432,7 @@ class Qwen2VLForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -314,7 +314,7 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): ...@@ -314,7 +314,7 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -688,7 +688,7 @@ class Qwen3MoeModel(nn.Module): ...@@ -688,7 +688,7 @@ class Qwen3MoeModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1033,7 +1033,7 @@ class Qwen3MoeForCausalLM( ...@@ -1033,7 +1033,7 @@ class Qwen3MoeForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -1005,7 +1005,7 @@ class Qwen3NextModel(nn.Module): ...@@ -1005,7 +1005,7 @@ class Qwen3NextModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1240,7 +1240,7 @@ class Qwen3NextForCausalLM( ...@@ -1240,7 +1240,7 @@ class Qwen3NextForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -261,7 +261,7 @@ class Qwen3NextMTP(nn.Module, QwenNextMixtureOfExperts): ...@@ -261,7 +261,7 @@ class Qwen3NextMTP(nn.Module, QwenNextMixtureOfExperts):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen3-Omni-Moe model (thinker part).""" """Inference-only Qwen3-Omni-Moe model (thinker part)."""
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import partial from functools import partial
from typing import Any from typing import Any
...@@ -104,10 +104,7 @@ from .utils import ( ...@@ -104,10 +104,7 @@ from .utils import (
_merge_multimodal_embeddings, _merge_multimodal_embeddings,
maybe_prefix, maybe_prefix,
) )
from .vision import ( from .vision import get_vit_attn_backend
get_llm_pos_ids_for_vision,
get_vit_attn_backend,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1001,7 +998,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel): ...@@ -1001,7 +998,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1822,7 +1819,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1822,7 +1819,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1867,323 +1864,268 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1867,323 +1864,268 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
return loaded_weights return loaded_weights
def get_mrope_input_positions( def _compute_audio_token_count(self, audio_feature_length: int) -> int:
self, """Compute audio tokens from feature length using Qwen3-Omni formula."""
input_tokens: list[int], return _get_feat_extract_output_lengths(
mm_features: list[MultiModalFeatureSpec], torch.tensor([audio_feature_length])
) -> tuple[torch.Tensor, int]: ).item()
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{
"image_grid_thw",
"video_grid_thw",
"second_per_grid_ts",
"audio_feature_lengths",
"use_audio_in_video",
},
)
image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", [])
second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
audio_feature_lengths = kwargs.get("audio_feature_lengths", [])
use_audio_in_video = any(kwargs.get("use_audio_in_video", []))
image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(
image_grid_thw
)
video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)(
video_grid_thw
)
input_ids = torch.tensor(input_tokens) def _get_audio_for_video_mapping(
if input_ids is None or input_ids.ndim != 1: self, mm_features: list[MultiModalFeatureSpec]
raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids") ) -> tuple[dict[int, int], set[int]]:
"""
Map video offset -> paired audio_feature_length for use_audio_in_video.
seq_len = input_ids.shape[0] When use_audio_in_video=True, audio is interleaved within video.
The pairing is based on feature order in mm_features.
if isinstance(audio_feature_lengths, list): Returns:
audio_feature_lengths = torch.tensor( Tuple of (video_offset -> audio_feature_length mapping,
audio_feature_lengths, dtype=torch.long set of paired audio offsets to skip)
) """
videos_with_audio = [
if not len(second_per_grid_ts) and len(video_grid_thw): f
second_per_grid_ts = 2.0 for f in mm_features
second_per_grids = ( if f.modality == "video"
torch.ones(len(video_grid_thw), dtype=torch.float32) and f.data.get("use_audio_in_video")
* second_per_grid_ts and f.data["use_audio_in_video"].data.item()
) ]
else: audios = [f for f in mm_features if f.modality == "audio"]
second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
mapping: dict[int, int] = {}
paired_audio_offsets: set[int] = set()
for i, video_f in enumerate(videos_with_audio):
if i < len(audios):
audio_len = audios[i].data["audio_feature_lengths"].data.item()
mapping[video_f.mm_position.offset] = audio_len
paired_audio_offsets.add(audios[i].mm_position.offset)
return mapping, paired_audio_offsets
def iter_mm_features(
self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, str, dict[str, Any]]]:
"""
Iterate over multimodal features sorted by position offset.
Yields: (offset, modality, feature_data) where feature_data contains:
- image: {"grid_t", "grid_h", "grid_w", "t_factor"}
- video: {"grid_t", "grid_h", "grid_w", "t_factor",
"use_audio_in_video", "audio_feature_length"}
- audio: {"audio_feature_length"}
"""
config = self.config config = self.config
spatial_merge_size = config.vision_config.spatial_merge_size spatial_merge_size = config.vision_config.spatial_merge_size
image_token_id = config.image_token_id
video_token_id = config.video_token_id
audio_token_id = config.audio_token_id
vision_start_token_id = config.vision_start_token_id
audio_start_token_id = config.audio_start_token_id
position_id_per_seconds = config.position_id_per_seconds position_id_per_seconds = config.position_id_per_seconds
vision_start_indices = torch.argwhere( sorted_features = sorted(mm_features, key=lambda f: f.mm_position.offset)
input_ids == vision_start_token_id audio_for_video, paired_audio_offsets = self._get_audio_for_video_mapping(
).squeeze(1) sorted_features
if vision_start_indices.numel() > 0:
vision_tokens = input_ids[vision_start_indices + 1]
else:
vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype)
audio_nums = torch.sum(input_ids == audio_start_token_id)
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (
(vision_tokens == audio_start_token_id).sum()
if use_audio_in_video
else (vision_tokens == video_token_id).sum()
) )
llm_pos_ids_list: list[torch.Tensor] = [] for mm_feature in sorted_features:
st = 0 offset = mm_feature.mm_position.offset
image_idx = 0 modality = mm_feature.modality
video_idx = 0
audio_idx = 0
remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums # noqa: E501
multimodal_nums = (
image_nums + audio_nums
if use_audio_in_video
else image_nums + video_nums + audio_nums
) # noqa: E501
for _ in range(multimodal_nums): if modality == "image":
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
if (image_token_id in input_tokens or video_token_id in input_tokens) and ( yield (
remain_videos > 0 or remain_images > 0 offset,
): "image",
ed_vision_start = input_tokens.index(vision_start_token_id, st) {
else: "grid_t": t,
ed_vision_start = len(input_tokens) + 1 "grid_h": h // spatial_merge_size,
if audio_token_id in input_tokens and remain_audios > 0: "grid_w": w // spatial_merge_size,
ed_audio_start = input_tokens.index(audio_start_token_id, st) "t_factor": position_id_per_seconds,
else: },
ed_audio_start = len(input_tokens) + 1
min_ed = min(ed_vision_start, ed_audio_start)
if min_ed == ed_audio_start:
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
llm_pos_ids_list.append(
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
audio_len = _get_feat_extract_output_lengths(
audio_feature_lengths[audio_idx]
) )
llm_pos_ids = ( elif modality == "video":
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1) t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
+ st_idx second_per_grid_ts = 2.0
if mm_feature.data.get("second_per_grid_ts"):
second_per_grid_ts = mm_feature.data[
"second_per_grid_ts"
].data.item()
use_audio_in_video = bool(
mm_feature.data.get("use_audio_in_video")
and mm_feature.data["use_audio_in_video"].data.item()
) )
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 yield (
eos_len = 1 offset,
llm_pos_ids_list.append( "video",
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) {
+ st_idx "grid_t": t,
"grid_h": h // spatial_merge_size,
"grid_w": w // spatial_merge_size,
"t_factor": second_per_grid_ts * position_id_per_seconds,
"use_audio_in_video": use_audio_in_video,
"audio_feature_length": audio_for_video.get(offset),
},
) )
st += text_len + bos_len + audio_len + eos_len elif modality == "audio":
if offset not in paired_audio_offsets:
audio_len = mm_feature.data["audio_feature_lengths"].data.item()
yield offset, "audio", {"audio_feature_length": audio_len}
def _compute_interleaved_positions(
self, start_idx: int, data: dict[str, Any]
) -> tuple[np.ndarray, int]:
"""
Compute positions for interleaved video+audio using Qwen3 token-by-token
interleaving logic.
Returns: (position_ids [3, N], total_token_count)
"""
grid_t = data["grid_t"]
grid_h = data["grid_h"]
grid_w = data["grid_w"]
t_factor = data["t_factor"]
audio_feature_length = data["audio_feature_length"]
audio_len = self._compute_audio_token_count(audio_feature_length)
h_index = np.tile(
np.arange(grid_h).reshape(1, -1, 1), (grid_t, 1, grid_w)
).flatten()
w_index = np.tile(
np.arange(grid_w).reshape(1, 1, -1), (grid_t, grid_h, 1)
).flatten()
t_index_raw = np.arange(grid_t)
t_index_scaled = (t_index_raw * t_factor).astype(np.int64)
t_index = np.repeat(t_index_scaled, grid_h * grid_w)
video_pos = np.stack([t_index, h_index, w_index]) + start_idx
audio_pos = np.broadcast_to(np.arange(audio_len), (3, audio_len)) + start_idx
video_t_values = video_pos[0]
audio_t_values = audio_pos[0]
pos_ids_list: list[np.ndarray] = []
video_idx, audio_idx = 0, 0
num_video = grid_t * grid_h * grid_w
while video_idx < num_video and audio_idx < audio_len:
if video_t_values[video_idx] <= audio_t_values[audio_idx]:
pos_ids_list.append(video_pos[:, video_idx : video_idx + 1])
video_idx += 1
else:
pos_ids_list.append(audio_pos[:, audio_idx : audio_idx + 1])
audio_idx += 1 audio_idx += 1
remain_audios -= 1
elif ( if video_idx < num_video:
min_ed == ed_vision_start pos_ids_list.append(video_pos[:, video_idx:])
and input_ids[ed_vision_start + 1] == image_token_id if audio_idx < audio_len:
): pos_ids_list.append(audio_pos[:, audio_idx:])
text_len = min_ed - st
if text_len != 0: total_tokens = num_video + audio_len
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 return np.concatenate(pos_ids_list, axis=1), total_tokens
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long) def get_mrope_input_positions(
.view(1, -1) self,
.expand(3, -1) input_tokens: list[int],
+ st_idx mm_features: list[MultiModalFeatureSpec],
) ) -> tuple[torch.Tensor, int]:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 """Compute M-RoPE input positions using mm_features directly."""
bos_len = 1 seq_len = len(input_tokens)
llm_pos_ids_list.append(
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) llm_pos_ids_list: list[np.ndarray] = []
+ st_idx st = 0
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 for offset, modality, data in self.iter_mm_features(mm_features):
grid_t = image_grid_thw[image_idx][0] text_len = offset - st
grid_hs = image_grid_thw[:, 1] st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
grid_ws = image_grid_thw[:, 2]
t_index = torch.arange(grid_t) * position_id_per_seconds if text_len > 0:
llm_pos_ids = get_llm_pos_ids_for_vision(
st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
llm_pos_ids_list.append(
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
st += text_len + bos_len + image_len + eos_len
image_idx += 1
remain_images -= 1
elif (
min_ed == ed_vision_start
and input_ids[ed_vision_start + 1] == video_token_id
and not use_audio_in_video
):
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
llm_pos_ids_list.append(
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* float(second_per_grids[video_idx].item())
* position_id_per_seconds
)
llm_pos_ids = get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
+ st_idx
)
st += text_len + bos_len + video_len + eos_len
video_idx += 1
remain_videos -= 1
elif (
min_ed == ed_vision_start
and ed_vision_start + 1 == ed_audio_start
and use_audio_in_video
):
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
bos_block = (
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(bos_block)
llm_pos_ids_list.append(bos_block)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
audio_len = _get_feat_extract_output_lengths(
audio_feature_lengths[audio_idx]
) )
audio_llm_pos_ids = ( st_idx += text_len
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx bos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
) llm_pos_ids_list.append(bos_pos)
grid_t = video_grid_thw[video_idx][0] st_idx += 1
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2] if modality == "audio":
t_index = ( audio_tokens = self._compute_audio_token_count(
torch.arange(grid_t) data["audio_feature_length"]
* float(second_per_grids[video_idx].item())
* position_id_per_seconds
) )
video_llm_pos_ids = get_llm_pos_ids_for_vision( audio_pos = (
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws np.broadcast_to(np.arange(audio_tokens), (3, audio_tokens)) + st_idx
) )
video_data_index, audio_data_index = 0, 0 llm_pos_ids_list.append(audio_pos)
while ( st_idx = int(audio_pos.max()) + 1
video_data_index < video_llm_pos_ids.shape[-1]
and audio_data_index < audio_llm_pos_ids.shape[-1] eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
): llm_pos_ids_list.append(eos_pos)
if ( st = offset + 1 + audio_tokens + 1
video_llm_pos_ids[0][video_data_index]
<= audio_llm_pos_ids[0][audio_data_index] elif modality == "image":
): grid_t = data["grid_t"]
llm_pos_ids_list.append( grid_h = data["grid_h"]
video_llm_pos_ids[ grid_w = data["grid_w"]
:, video_data_index : video_data_index + 1 t_factor = data["t_factor"]
]
) grid_indices = np.indices((grid_t, grid_h, grid_w))
video_data_index += 1 if t_factor != 1.0:
else: grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
llm_pos_ids_list.append( llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx)
audio_llm_pos_ids[
:, audio_data_index : audio_data_index + 1 image_len = grid_t * grid_h * grid_w
] st_idx = int(llm_pos_ids_list[-1].max()) + 1
)
audio_data_index += 1 eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
if video_data_index < video_llm_pos_ids.shape[-1]: llm_pos_ids_list.append(eos_pos)
llm_pos_ids_list.append( st = offset + 1 + image_len + 1
video_llm_pos_ids[
:, video_data_index : video_llm_pos_ids.shape[-1] elif modality == "video":
] grid_t = data["grid_t"]
) grid_h = data["grid_h"]
if audio_data_index < audio_llm_pos_ids.shape[-1]: grid_w = data["grid_w"]
llm_pos_ids_list.append( t_factor = data["t_factor"]
audio_llm_pos_ids[
:, audio_data_index : audio_llm_pos_ids.shape[-1] if not data["use_audio_in_video"]:
] grid_indices = np.indices((grid_t, grid_h, grid_w))
if t_factor != 1.0:
grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx)
video_len = grid_t * grid_h * grid_w
st_idx = int(llm_pos_ids_list[-1].max()) + 1
eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
llm_pos_ids_list.append(eos_pos)
st = offset + 1 + video_len + 1
else:
audio_bos_pos = np.broadcast_to(np.array([st_idx - 1]), (3, 1))
llm_pos_ids_list.append(audio_bos_pos)
pos_ids, _ = self._compute_interleaved_positions(st_idx, data)
llm_pos_ids_list.append(pos_ids)
st_idx = int(pos_ids.max()) + 1
eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
llm_pos_ids_list.append(eos_pos)
llm_pos_ids_list.append(eos_pos)
video_len = grid_t * grid_h * grid_w
audio_len = self._compute_audio_token_count(
data["audio_feature_length"]
) )
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) st = offset + 2 + video_len + audio_len + 2
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
eos_block = (
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(eos_block)
llm_pos_ids_list.append(eos_block)
st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 # noqa: E501
audio_idx += 1
video_idx += 1
remain_videos -= 1
remain_audios -= 1
if st < len(input_tokens): if st < seq_len:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
text_len = len(input_tokens) - st text_len = seq_len - st
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1) np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
+ st_idx
) )
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
if llm_positions.shape[1] != seq_len: if llm_positions.shape[1] != seq_len:
raise RuntimeError("Position ids length mismatch with input ids length") raise RuntimeError("Position ids length mismatch with input ids length")
mrope_position_delta = llm_positions.max() + 1 - seq_len mrope_position_delta = int(llm_positions.max()) + 1 - seq_len
return llm_positions, mrope_position_delta return torch.from_numpy(llm_positions), mrope_position_delta
def get_mm_mapping(self) -> MultiModelKeys: def get_mm_mapping(self) -> MultiModelKeys:
""" """
......
...@@ -1122,7 +1122,7 @@ class Qwen3LLMModel(Qwen3Model): ...@@ -1122,7 +1122,7 @@ class Qwen3LLMModel(Qwen3Model):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -2004,7 +2004,7 @@ class Qwen3VLForConditionalGeneration( ...@@ -2004,7 +2004,7 @@ class Qwen3VLForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -94,7 +94,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel): ...@@ -94,7 +94,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -810,7 +810,7 @@ class QwenVLForConditionalGeneration( ...@@ -810,7 +810,7 @@ class QwenVLForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -320,8 +320,9 @@ _MULTIMODAL_MODELS = { ...@@ -320,8 +320,9 @@ _MULTIMODAL_MODELS = {
), ),
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"), "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501 "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"), # noqa: E501 "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
"GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"), # noqa: E501
"GraniteSpeechForConditionalGeneration": ( "GraniteSpeechForConditionalGeneration": (
"granite_speech", "granite_speech",
"GraniteSpeechForConditionalGeneration", "GraniteSpeechForConditionalGeneration",
...@@ -360,6 +361,7 @@ _MULTIMODAL_MODELS = { ...@@ -360,6 +361,7 @@ _MULTIMODAL_MODELS = {
), ),
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
"KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"), # noqa: E501
"LightOnOCRForConditionalGeneration": ( "LightOnOCRForConditionalGeneration": (
"lightonocr", "lightonocr",
"LightOnOCRForConditionalGeneration", "LightOnOCRForConditionalGeneration",
...@@ -473,6 +475,7 @@ _SPECULATIVE_DECODING_MODELS = { ...@@ -473,6 +475,7 @@ _SPECULATIVE_DECODING_MODELS = {
"LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"), "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"), "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
"GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
"MedusaModel": ("medusa", "Medusa"), "MedusaModel": ("medusa", "Medusa"),
"OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"), "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"), "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
......
...@@ -334,7 +334,7 @@ class SeedOssModel(nn.Module): ...@@ -334,7 +334,7 @@ class SeedOssModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -467,7 +467,7 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -467,7 +467,7 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -898,7 +898,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -898,7 +898,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -465,7 +465,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -465,7 +465,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -246,7 +246,7 @@ class StableLMEpochModel(nn.Module): ...@@ -246,7 +246,7 @@ class StableLMEpochModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -332,7 +332,7 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -332,7 +332,7 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment