Commit df704163 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1 (models)

parent d7db129a
......@@ -342,7 +342,7 @@ class Plamo3Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -412,7 +412,7 @@ class Plamo3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -434,4 +434,4 @@ class Plamo3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
return loader.load_weights(weights)
\ No newline at end of file
......@@ -243,7 +243,7 @@ class QWenModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......@@ -425,7 +425,7 @@ class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -439,7 +439,7 @@ class Qwen2Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -659,7 +659,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -1298,7 +1298,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -1330,4 +1330,4 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
language_model="language_model",
connector="merger.",
tower_model=["visual.", "audio_tower."],
)
)
\ No newline at end of file
......@@ -451,7 +451,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -408,7 +408,7 @@ class Qwen2MoeModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -633,7 +633,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -79,7 +79,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -1432,7 +1432,7 @@ class Qwen2VLForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -314,7 +314,7 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -252,11 +252,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
final_hidden_states
)
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens]
# return to 1d if input is 1d
return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
......@@ -688,7 +683,7 @@ class Qwen3MoeModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -1033,7 +1028,7 @@ class Qwen3MoeForCausalLM(
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -102,7 +102,6 @@ KVCache = tuple[torch.Tensor, torch.Tensor]
class Qwen3NextSparseMoeBlock(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......@@ -1005,7 +1004,7 @@ class Qwen3NextModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -1240,7 +1239,7 @@ class Qwen3NextForCausalLM(
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -261,7 +261,7 @@ class Qwen3NextMTP(nn.Module, QwenNextMixtureOfExperts):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
......@@ -292,4 +292,4 @@ class Qwen3NextMTP(nn.Module, QwenNextMixtureOfExperts):
yield name, weight
loader = AutoWeightsLoader(self)
return loader.load_weights(remap_weight_names(weights))
return loader.load_weights(remap_weight_names(weights))
\ No newline at end of file
......@@ -22,7 +22,7 @@
# limitations under the License.
"""Inference-only Qwen3-Omni-Moe model (thinker part)."""
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial
from typing import Any
......@@ -104,7 +104,10 @@ from .utils import (
_merge_multimodal_embeddings,
maybe_prefix,
)
from .vision import get_vit_attn_backend
from .vision import (
get_llm_pos_ids_for_vision,
get_vit_attn_backend,
)
logger = init_logger(__name__)
......@@ -998,7 +1001,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -1819,7 +1822,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -1864,268 +1867,323 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
return loaded_weights
def _compute_audio_token_count(self, audio_feature_length: int) -> int:
"""Compute audio tokens from feature length using Qwen3-Omni formula."""
return _get_feat_extract_output_lengths(
torch.tensor([audio_feature_length])
).item()
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{
"image_grid_thw",
"video_grid_thw",
"second_per_grid_ts",
"audio_feature_lengths",
"use_audio_in_video",
},
)
image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", [])
second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
audio_feature_lengths = kwargs.get("audio_feature_lengths", [])
use_audio_in_video = any(kwargs.get("use_audio_in_video", []))
image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(
image_grid_thw
)
video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)(
video_grid_thw
)
input_ids = torch.tensor(input_tokens)
if input_ids is None or input_ids.ndim != 1:
raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids")
def _get_audio_for_video_mapping(
self, mm_features: list[MultiModalFeatureSpec]
) -> tuple[dict[int, int], set[int]]:
"""
Map video offset -> paired audio_feature_length for use_audio_in_video.
seq_len = input_ids.shape[0]
When use_audio_in_video=True, audio is interleaved within video.
The pairing is based on feature order in mm_features.
if isinstance(audio_feature_lengths, list):
audio_feature_lengths = torch.tensor(
audio_feature_lengths, dtype=torch.long
)
Returns:
Tuple of (video_offset -> audio_feature_length mapping,
set of paired audio offsets to skip)
"""
videos_with_audio = [
f
for f in mm_features
if f.modality == "video"
and f.data.get("use_audio_in_video")
and f.data["use_audio_in_video"].data.item()
]
audios = [f for f in mm_features if f.modality == "audio"]
mapping: dict[int, int] = {}
paired_audio_offsets: set[int] = set()
for i, video_f in enumerate(videos_with_audio):
if i < len(audios):
audio_len = audios[i].data["audio_feature_lengths"].data.item()
mapping[video_f.mm_position.offset] = audio_len
paired_audio_offsets.add(audios[i].mm_position.offset)
return mapping, paired_audio_offsets
def iter_mm_features(
self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, str, dict[str, Any]]]:
"""
Iterate over multimodal features sorted by position offset.
if not len(second_per_grid_ts) and len(video_grid_thw):
second_per_grid_ts = 2.0
second_per_grids = (
torch.ones(len(video_grid_thw), dtype=torch.float32)
* second_per_grid_ts
)
else:
second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
Yields: (offset, modality, feature_data) where feature_data contains:
- image: {"grid_t", "grid_h", "grid_w", "t_factor"}
- video: {"grid_t", "grid_h", "grid_w", "t_factor",
"use_audio_in_video", "audio_feature_length"}
- audio: {"audio_feature_length"}
"""
config = self.config
spatial_merge_size = config.vision_config.spatial_merge_size
image_token_id = config.image_token_id
video_token_id = config.video_token_id
audio_token_id = config.audio_token_id
vision_start_token_id = config.vision_start_token_id
audio_start_token_id = config.audio_start_token_id
position_id_per_seconds = config.position_id_per_seconds
sorted_features = sorted(mm_features, key=lambda f: f.mm_position.offset)
audio_for_video, paired_audio_offsets = self._get_audio_for_video_mapping(
sorted_features
vision_start_indices = torch.argwhere(
input_ids == vision_start_token_id
).squeeze(1)
if vision_start_indices.numel() > 0:
vision_tokens = input_ids[vision_start_indices + 1]
else:
vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype)
audio_nums = torch.sum(input_ids == audio_start_token_id)
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (
(vision_tokens == audio_start_token_id).sum()
if use_audio_in_video
else (vision_tokens == video_token_id).sum()
)
for mm_feature in sorted_features:
offset = mm_feature.mm_position.offset
modality = mm_feature.modality
llm_pos_ids_list: list[torch.Tensor] = []
st = 0
image_idx = 0
video_idx = 0
audio_idx = 0
remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums # noqa: E501
multimodal_nums = (
image_nums + audio_nums
if use_audio_in_video
else image_nums + video_nums + audio_nums
) # noqa: E501
if modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
yield (
offset,
"image",
{
"grid_t": t,
"grid_h": h // spatial_merge_size,
"grid_w": w // spatial_merge_size,
"t_factor": position_id_per_seconds,
},
for _ in range(multimodal_nums):
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
if (image_token_id in input_tokens or video_token_id in input_tokens) and (
remain_videos > 0 or remain_images > 0
):
ed_vision_start = input_tokens.index(vision_start_token_id, st)
else:
ed_vision_start = len(input_tokens) + 1
if audio_token_id in input_tokens and remain_audios > 0:
ed_audio_start = input_tokens.index(audio_start_token_id, st)
else:
ed_audio_start = len(input_tokens) + 1
min_ed = min(ed_vision_start, ed_audio_start)
if min_ed == ed_audio_start:
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
llm_pos_ids_list.append(
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
elif modality == "video":
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
second_per_grid_ts = 2.0
if mm_feature.data.get("second_per_grid_ts"):
second_per_grid_ts = mm_feature.data[
"second_per_grid_ts"
].data.item()
use_audio_in_video = bool(
mm_feature.data.get("use_audio_in_video")
and mm_feature.data["use_audio_in_video"].data.item()
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
audio_len = _get_feat_extract_output_lengths(
audio_feature_lengths[audio_idx]
)
yield (
offset,
"video",
{
"grid_t": t,
"grid_h": h // spatial_merge_size,
"grid_w": w // spatial_merge_size,
"t_factor": second_per_grid_ts * position_id_per_seconds,
"use_audio_in_video": use_audio_in_video,
"audio_feature_length": audio_for_video.get(offset),
},
llm_pos_ids = (
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
elif modality == "audio":
if offset not in paired_audio_offsets:
audio_len = mm_feature.data["audio_feature_lengths"].data.item()
yield offset, "audio", {"audio_feature_length": audio_len}
def _compute_interleaved_positions(
self, start_idx: int, data: dict[str, Any]
) -> tuple[np.ndarray, int]:
"""
Compute positions for interleaved video+audio using Qwen3 token-by-token
interleaving logic.
Returns: (position_ids [3, N], total_token_count)
"""
grid_t = data["grid_t"]
grid_h = data["grid_h"]
grid_w = data["grid_w"]
t_factor = data["t_factor"]
audio_feature_length = data["audio_feature_length"]
audio_len = self._compute_audio_token_count(audio_feature_length)
h_index = np.tile(
np.arange(grid_h).reshape(1, -1, 1), (grid_t, 1, grid_w)
).flatten()
w_index = np.tile(
np.arange(grid_w).reshape(1, 1, -1), (grid_t, grid_h, 1)
).flatten()
t_index_raw = np.arange(grid_t)
t_index_scaled = (t_index_raw * t_factor).astype(np.int64)
t_index = np.repeat(t_index_scaled, grid_h * grid_w)
video_pos = np.stack([t_index, h_index, w_index]) + start_idx
audio_pos = np.broadcast_to(np.arange(audio_len), (3, audio_len)) + start_idx
video_t_values = video_pos[0]
audio_t_values = audio_pos[0]
pos_ids_list: list[np.ndarray] = []
video_idx, audio_idx = 0, 0
num_video = grid_t * grid_h * grid_w
while video_idx < num_video and audio_idx < audio_len:
if video_t_values[video_idx] <= audio_t_values[audio_idx]:
pos_ids_list.append(video_pos[:, video_idx : video_idx + 1])
video_idx += 1
else:
pos_ids_list.append(audio_pos[:, audio_idx : audio_idx + 1])
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
llm_pos_ids_list.append(
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
st += text_len + bos_len + audio_len + eos_len
audio_idx += 1
if video_idx < num_video:
pos_ids_list.append(video_pos[:, video_idx:])
if audio_idx < audio_len:
pos_ids_list.append(audio_pos[:, audio_idx:])
total_tokens = num_video + audio_len
return np.concatenate(pos_ids_list, axis=1), total_tokens
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
"""Compute M-RoPE input positions using mm_features directly."""
seq_len = len(input_tokens)
llm_pos_ids_list: list[np.ndarray] = []
st = 0
for offset, modality, data in self.iter_mm_features(mm_features):
text_len = offset - st
st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
if text_len > 0:
remain_audios -= 1
elif (
min_ed == ed_vision_start
and input_ids[ed_vision_start + 1] == image_token_id
):
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
st_idx += text_len
bos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
llm_pos_ids_list.append(bos_pos)
st_idx += 1
if modality == "audio":
audio_tokens = self._compute_audio_token_count(
data["audio_feature_length"]
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = torch.arange(grid_t) * position_id_per_seconds
llm_pos_ids = get_llm_pos_ids_for_vision(
st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
audio_pos = (
np.broadcast_to(np.arange(audio_tokens), (3, audio_tokens)) + st_idx
image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
llm_pos_ids_list.append(
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(audio_pos)
st_idx = int(audio_pos.max()) + 1
eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
llm_pos_ids_list.append(eos_pos)
st = offset + 1 + audio_tokens + 1
elif modality == "image":
grid_t = data["grid_t"]
grid_h = data["grid_h"]
grid_w = data["grid_w"]
t_factor = data["t_factor"]
grid_indices = np.indices((grid_t, grid_h, grid_w))
if t_factor != 1.0:
grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx)
image_len = grid_t * grid_h * grid_w
st_idx = int(llm_pos_ids_list[-1].max()) + 1
eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
llm_pos_ids_list.append(eos_pos)
st = offset + 1 + image_len + 1
elif modality == "video":
grid_t = data["grid_t"]
grid_h = data["grid_h"]
grid_w = data["grid_w"]
t_factor = data["t_factor"]
if not data["use_audio_in_video"]:
grid_indices = np.indices((grid_t, grid_h, grid_w))
if t_factor != 1.0:
grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx)
video_len = grid_t * grid_h * grid_w
st_idx = int(llm_pos_ids_list[-1].max()) + 1
eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
llm_pos_ids_list.append(eos_pos)
st = offset + 1 + video_len + 1
else:
audio_bos_pos = np.broadcast_to(np.array([st_idx - 1]), (3, 1))
llm_pos_ids_list.append(audio_bos_pos)
pos_ids, _ = self._compute_interleaved_positions(st_idx, data)
llm_pos_ids_list.append(pos_ids)
st_idx = int(pos_ids.max()) + 1
eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
llm_pos_ids_list.append(eos_pos)
llm_pos_ids_list.append(eos_pos)
video_len = grid_t * grid_h * grid_w
audio_len = self._compute_audio_token_count(
data["audio_feature_length"]
st += text_len + bos_len + image_len + eos_len
image_idx += 1
remain_images -= 1
elif (
min_ed == ed_vision_start
and input_ids[ed_vision_start + 1] == video_token_id
and not use_audio_in_video
):
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
llm_pos_ids_list.append(
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* float(second_per_grids[video_idx].item())
* position_id_per_seconds
)
llm_pos_ids = get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
llm_pos_ids_list.append(
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
st += text_len + bos_len + video_len + eos_len
video_idx += 1
remain_videos -= 1
elif (
min_ed == ed_vision_start
and ed_vision_start + 1 == ed_audio_start
and use_audio_in_video
):
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st = offset + 2 + video_len + audio_len + 2
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
bos_block = (
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(bos_block)
llm_pos_ids_list.append(bos_block)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
audio_len = _get_feat_extract_output_lengths(
audio_feature_lengths[audio_idx]
)
audio_llm_pos_ids = (
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* float(second_per_grids[video_idx].item())
* position_id_per_seconds
)
video_llm_pos_ids = get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
video_data_index, audio_data_index = 0, 0
while (
video_data_index < video_llm_pos_ids.shape[-1]
and audio_data_index < audio_llm_pos_ids.shape[-1]
):
if (
video_llm_pos_ids[0][video_data_index]
<= audio_llm_pos_ids[0][audio_data_index]
):
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_data_index + 1
]
)
video_data_index += 1
else:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_data_index + 1
]
)
audio_data_index += 1
if video_data_index < video_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_llm_pos_ids.shape[-1]
]
)
if audio_data_index < audio_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_llm_pos_ids.shape[-1]
]
)
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
eos_block = (
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(eos_block)
llm_pos_ids_list.append(eos_block)
st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 # noqa: E501
audio_idx += 1
video_idx += 1
remain_videos -= 1
remain_audios -= 1
if st < seq_len:
st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
text_len = seq_len - st
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
if llm_positions.shape[1] != seq_len:
raise RuntimeError("Position ids length mismatch with input ids length")
mrope_position_delta = int(llm_positions.max()) + 1 - seq_len
return torch.from_numpy(llm_positions), mrope_position_delta
mrope_position_delta = llm_positions.max() + 1 - seq_len
return llm_positions, mrope_position_delta
def get_mm_mapping(self) -> MultiModelKeys:
"""
......@@ -2135,4 +2193,4 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
language_model="language_model",
connector="visual.merger",
tower_model=["visual.", "audio_tower."],
)
)
\ No newline at end of file
......@@ -1122,7 +1122,7 @@ class Qwen3LLMModel(Qwen3Model):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -2004,7 +2004,7 @@ class Qwen3VLForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -94,7 +94,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -476,4 +476,4 @@ class Qwen3VLMoeForConditionalGeneration(
)
# Set MoE hyperparameters
self.set_moe_parameters()
self.set_moe_parameters()
\ No newline at end of file
......@@ -810,7 +810,7 @@ class QwenVLForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -321,8 +321,8 @@ _MULTIMODAL_MODELS = {
),
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"), # noqa: E501
"GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"), # noqa: E501
"GraniteSpeechForConditionalGeneration": (
"granite_speech",
......@@ -476,7 +476,6 @@ _SPECULATIVE_DECODING_MODELS = {
"LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
"GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
"MedusaModel": ("medusa", "Medusa"),
"OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
......
......@@ -334,7 +334,7 @@ class SeedOssModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -467,7 +467,7 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -489,4 +489,4 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
return loader.load_weights(weights)
\ No newline at end of file
......@@ -898,7 +898,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -944,4 +944,4 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
"track_token",
]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights)
return loader.load_weights(weights)
\ No newline at end of file
......@@ -465,7 +465,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -481,4 +481,4 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
return loader.load_weights(weights)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment