"tests/vscode:/vscode.git/clone" did not exist on "5135c321b06888cbe2708fa0a601d62165269607"
Commit c80f5968 authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.15.1-dev' into v0.15.1-dev

# Conflicts:
#	vllm/model_executor/layers/fused_moe/config.py
#	vllm/model_executor/layers/fused_moe/layer.py
#	vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_marlin.py
parents 74306deb 530e785f
...@@ -314,7 +314,7 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): ...@@ -314,7 +314,7 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -252,11 +252,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -252,11 +252,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
final_hidden_states final_hidden_states
) )
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens]
# return to 1d if input is 1d # return to 1d if input is 1d
return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
...@@ -688,7 +683,7 @@ class Qwen3MoeModel(nn.Module): ...@@ -688,7 +683,7 @@ class Qwen3MoeModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1033,7 +1028,7 @@ class Qwen3MoeForCausalLM( ...@@ -1033,7 +1028,7 @@ class Qwen3MoeForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -102,7 +102,6 @@ KVCache = tuple[torch.Tensor, torch.Tensor] ...@@ -102,7 +102,6 @@ KVCache = tuple[torch.Tensor, torch.Tensor]
class Qwen3NextSparseMoeBlock(nn.Module): class Qwen3NextSparseMoeBlock(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
...@@ -1005,7 +1004,7 @@ class Qwen3NextModel(nn.Module): ...@@ -1005,7 +1004,7 @@ class Qwen3NextModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1240,7 +1239,7 @@ class Qwen3NextForCausalLM( ...@@ -1240,7 +1239,7 @@ class Qwen3NextForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -261,7 +261,7 @@ class Qwen3NextMTP(nn.Module, QwenNextMixtureOfExperts): ...@@ -261,7 +261,7 @@ class Qwen3NextMTP(nn.Module, QwenNextMixtureOfExperts):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
...@@ -292,4 +292,4 @@ class Qwen3NextMTP(nn.Module, QwenNextMixtureOfExperts): ...@@ -292,4 +292,4 @@ class Qwen3NextMTP(nn.Module, QwenNextMixtureOfExperts):
yield name, weight yield name, weight
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(remap_weight_names(weights)) return loader.load_weights(remap_weight_names(weights))
\ No newline at end of file
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen3-Omni-Moe model (thinker part).""" """Inference-only Qwen3-Omni-Moe model (thinker part)."""
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial from functools import partial
from typing import Any from typing import Any
...@@ -104,7 +104,10 @@ from .utils import ( ...@@ -104,7 +104,10 @@ from .utils import (
_merge_multimodal_embeddings, _merge_multimodal_embeddings,
maybe_prefix, maybe_prefix,
) )
from .vision import get_vit_attn_backend from .vision import (
get_llm_pos_ids_for_vision,
get_vit_attn_backend,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -998,7 +1001,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel): ...@@ -998,7 +1001,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1819,7 +1822,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1819,7 +1822,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1864,268 +1867,323 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1864,268 +1867,323 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
return loaded_weights return loaded_weights
def _compute_audio_token_count(self, audio_feature_length: int) -> int: def get_mrope_input_positions(
"""Compute audio tokens from feature length using Qwen3-Omni formula.""" self,
return _get_feat_extract_output_lengths( input_tokens: list[int],
torch.tensor([audio_feature_length]) mm_features: list[MultiModalFeatureSpec],
).item() ) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{
"image_grid_thw",
"video_grid_thw",
"second_per_grid_ts",
"audio_feature_lengths",
"use_audio_in_video",
},
)
image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", [])
second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
audio_feature_lengths = kwargs.get("audio_feature_lengths", [])
use_audio_in_video = any(kwargs.get("use_audio_in_video", []))
image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(
image_grid_thw
)
video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)(
video_grid_thw
)
input_ids = torch.tensor(input_tokens)
if input_ids is None or input_ids.ndim != 1:
raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids")
def _get_audio_for_video_mapping( seq_len = input_ids.shape[0]
self, mm_features: list[MultiModalFeatureSpec]
) -> tuple[dict[int, int], set[int]]:
"""
Map video offset -> paired audio_feature_length for use_audio_in_video.
When use_audio_in_video=True, audio is interleaved within video. if isinstance(audio_feature_lengths, list):
The pairing is based on feature order in mm_features. audio_feature_lengths = torch.tensor(
audio_feature_lengths, dtype=torch.long
)
Returns: if not len(second_per_grid_ts) and len(video_grid_thw):
Tuple of (video_offset -> audio_feature_length mapping, second_per_grid_ts = 2.0
set of paired audio offsets to skip) second_per_grids = (
""" torch.ones(len(video_grid_thw), dtype=torch.float32)
videos_with_audio = [ * second_per_grid_ts
f )
for f in mm_features else:
if f.modality == "video" second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
and f.data.get("use_audio_in_video")
and f.data["use_audio_in_video"].data.item()
]
audios = [f for f in mm_features if f.modality == "audio"]
mapping: dict[int, int] = {}
paired_audio_offsets: set[int] = set()
for i, video_f in enumerate(videos_with_audio):
if i < len(audios):
audio_len = audios[i].data["audio_feature_lengths"].data.item()
mapping[video_f.mm_position.offset] = audio_len
paired_audio_offsets.add(audios[i].mm_position.offset)
return mapping, paired_audio_offsets
def iter_mm_features(
self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, str, dict[str, Any]]]:
"""
Iterate over multimodal features sorted by position offset.
Yields: (offset, modality, feature_data) where feature_data contains:
- image: {"grid_t", "grid_h", "grid_w", "t_factor"}
- video: {"grid_t", "grid_h", "grid_w", "t_factor",
"use_audio_in_video", "audio_feature_length"}
- audio: {"audio_feature_length"}
"""
config = self.config config = self.config
spatial_merge_size = config.vision_config.spatial_merge_size spatial_merge_size = config.vision_config.spatial_merge_size
image_token_id = config.image_token_id
video_token_id = config.video_token_id
audio_token_id = config.audio_token_id
vision_start_token_id = config.vision_start_token_id
audio_start_token_id = config.audio_start_token_id
position_id_per_seconds = config.position_id_per_seconds position_id_per_seconds = config.position_id_per_seconds
sorted_features = sorted(mm_features, key=lambda f: f.mm_position.offset) vision_start_indices = torch.argwhere(
audio_for_video, paired_audio_offsets = self._get_audio_for_video_mapping( input_ids == vision_start_token_id
sorted_features ).squeeze(1)
if vision_start_indices.numel() > 0:
vision_tokens = input_ids[vision_start_indices + 1]
else:
vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype)
audio_nums = torch.sum(input_ids == audio_start_token_id)
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (
(vision_tokens == audio_start_token_id).sum()
if use_audio_in_video
else (vision_tokens == video_token_id).sum()
) )
for mm_feature in sorted_features: llm_pos_ids_list: list[torch.Tensor] = []
offset = mm_feature.mm_position.offset st = 0
modality = mm_feature.modality image_idx = 0
video_idx = 0
audio_idx = 0
remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums # noqa: E501
multimodal_nums = (
image_nums + audio_nums
if use_audio_in_video
else image_nums + video_nums + audio_nums
) # noqa: E501
if modality == "image": for _ in range(multimodal_nums):
t, h, w = mm_feature.data["image_grid_thw"].data.tolist() st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
yield ( if (image_token_id in input_tokens or video_token_id in input_tokens) and (
offset, remain_videos > 0 or remain_images > 0
"image", ):
{ ed_vision_start = input_tokens.index(vision_start_token_id, st)
"grid_t": t, else:
"grid_h": h // spatial_merge_size, ed_vision_start = len(input_tokens) + 1
"grid_w": w // spatial_merge_size, if audio_token_id in input_tokens and remain_audios > 0:
"t_factor": position_id_per_seconds, ed_audio_start = input_tokens.index(audio_start_token_id, st)
}, else:
ed_audio_start = len(input_tokens) + 1
min_ed = min(ed_vision_start, ed_audio_start)
if min_ed == ed_audio_start:
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
llm_pos_ids_list.append(
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
) )
elif modality == "video": st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
t, h, w = mm_feature.data["video_grid_thw"].data.tolist() audio_len = _get_feat_extract_output_lengths(
second_per_grid_ts = 2.0 audio_feature_lengths[audio_idx]
if mm_feature.data.get("second_per_grid_ts"):
second_per_grid_ts = mm_feature.data[
"second_per_grid_ts"
].data.item()
use_audio_in_video = bool(
mm_feature.data.get("use_audio_in_video")
and mm_feature.data["use_audio_in_video"].data.item()
) )
llm_pos_ids = (
yield ( torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
offset, + st_idx
"video",
{
"grid_t": t,
"grid_h": h // spatial_merge_size,
"grid_w": w // spatial_merge_size,
"t_factor": second_per_grid_ts * position_id_per_seconds,
"use_audio_in_video": use_audio_in_video,
"audio_feature_length": audio_for_video.get(offset),
},
) )
elif modality == "audio": llm_pos_ids_list.append(llm_pos_ids)
if offset not in paired_audio_offsets: st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
audio_len = mm_feature.data["audio_feature_lengths"].data.item() eos_len = 1
yield offset, "audio", {"audio_feature_length": audio_len} llm_pos_ids_list.append(
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
def _compute_interleaved_positions( + st_idx
self, start_idx: int, data: dict[str, Any] )
) -> tuple[np.ndarray, int]: st += text_len + bos_len + audio_len + eos_len
"""
Compute positions for interleaved video+audio using Qwen3 token-by-token
interleaving logic.
Returns: (position_ids [3, N], total_token_count)
"""
grid_t = data["grid_t"]
grid_h = data["grid_h"]
grid_w = data["grid_w"]
t_factor = data["t_factor"]
audio_feature_length = data["audio_feature_length"]
audio_len = self._compute_audio_token_count(audio_feature_length)
h_index = np.tile(
np.arange(grid_h).reshape(1, -1, 1), (grid_t, 1, grid_w)
).flatten()
w_index = np.tile(
np.arange(grid_w).reshape(1, 1, -1), (grid_t, grid_h, 1)
).flatten()
t_index_raw = np.arange(grid_t)
t_index_scaled = (t_index_raw * t_factor).astype(np.int64)
t_index = np.repeat(t_index_scaled, grid_h * grid_w)
video_pos = np.stack([t_index, h_index, w_index]) + start_idx
audio_pos = np.broadcast_to(np.arange(audio_len), (3, audio_len)) + start_idx
video_t_values = video_pos[0]
audio_t_values = audio_pos[0]
pos_ids_list: list[np.ndarray] = []
video_idx, audio_idx = 0, 0
num_video = grid_t * grid_h * grid_w
while video_idx < num_video and audio_idx < audio_len:
if video_t_values[video_idx] <= audio_t_values[audio_idx]:
pos_ids_list.append(video_pos[:, video_idx : video_idx + 1])
video_idx += 1
else:
pos_ids_list.append(audio_pos[:, audio_idx : audio_idx + 1])
audio_idx += 1 audio_idx += 1
remain_audios -= 1
if video_idx < num_video: elif (
pos_ids_list.append(video_pos[:, video_idx:]) min_ed == ed_vision_start
if audio_idx < audio_len: and input_ids[ed_vision_start + 1] == image_token_id
pos_ids_list.append(audio_pos[:, audio_idx:]) ):
text_len = min_ed - st
total_tokens = num_video + audio_len if text_len != 0:
return np.concatenate(pos_ids_list, axis=1), total_tokens st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
def get_mrope_input_positions( torch.arange(text_len, dtype=torch.long)
self, .view(1, -1)
input_tokens: list[int], .expand(3, -1)
mm_features: list[MultiModalFeatureSpec], + st_idx
) -> tuple[torch.Tensor, int]: )
"""Compute M-RoPE input positions using mm_features directly.""" st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
seq_len = len(input_tokens) bos_len = 1
llm_pos_ids_list: list[np.ndarray] = []
st = 0
for offset, modality, data in self.iter_mm_features(mm_features):
text_len = offset - st
st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
if text_len > 0:
llm_pos_ids_list.append( llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
) )
st_idx += text_len st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
grid_t = image_grid_thw[image_idx][0]
bos_pos = np.broadcast_to(np.array([st_idx]), (3, 1)) grid_hs = image_grid_thw[:, 1]
llm_pos_ids_list.append(bos_pos) grid_ws = image_grid_thw[:, 2]
st_idx += 1 t_index = torch.arange(grid_t) * position_id_per_seconds
llm_pos_ids = get_llm_pos_ids_for_vision(
if modality == "audio": st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
audio_tokens = self._compute_audio_token_count(
data["audio_feature_length"]
) )
audio_pos = ( image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
np.broadcast_to(np.arange(audio_tokens), (3, audio_tokens)) + st_idx llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
llm_pos_ids_list.append(
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
) )
llm_pos_ids_list.append(audio_pos) st += text_len + bos_len + image_len + eos_len
st_idx = int(audio_pos.max()) + 1 image_idx += 1
remain_images -= 1
eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1)) elif (
llm_pos_ids_list.append(eos_pos) min_ed == ed_vision_start
st = offset + 1 + audio_tokens + 1 and input_ids[ed_vision_start + 1] == video_token_id
and not use_audio_in_video
elif modality == "image": ):
grid_t = data["grid_t"] text_len = min_ed - st
grid_h = data["grid_h"] if text_len != 0:
grid_w = data["grid_w"] st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
t_factor = data["t_factor"] llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
grid_indices = np.indices((grid_t, grid_h, grid_w)) .view(1, -1)
if t_factor != 1.0: .expand(3, -1)
grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64) + st_idx
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx) )
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
image_len = grid_t * grid_h * grid_w bos_len = 1
st_idx = int(llm_pos_ids_list[-1].max()) + 1 llm_pos_ids_list.append(
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1)) + st_idx
llm_pos_ids_list.append(eos_pos) )
st = offset + 1 + image_len + 1 st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
grid_t = video_grid_thw[video_idx][0]
elif modality == "video": grid_hs = video_grid_thw[:, 1]
grid_t = data["grid_t"] grid_ws = video_grid_thw[:, 2]
grid_h = data["grid_h"] t_index = (
grid_w = data["grid_w"] torch.arange(grid_t)
t_factor = data["t_factor"] * float(second_per_grids[video_idx].item())
* position_id_per_seconds
if not data["use_audio_in_video"]: )
grid_indices = np.indices((grid_t, grid_h, grid_w)) llm_pos_ids = get_llm_pos_ids_for_vision(
if t_factor != 1.0: st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64) )
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx) video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
video_len = grid_t * grid_h * grid_w st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
st_idx = int(llm_pos_ids_list[-1].max()) + 1 eos_len = 1
llm_pos_ids_list.append(
eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1)) torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
llm_pos_ids_list.append(eos_pos) + st_idx
st = offset + 1 + video_len + 1 )
else: st += text_len + bos_len + video_len + eos_len
audio_bos_pos = np.broadcast_to(np.array([st_idx - 1]), (3, 1)) video_idx += 1
llm_pos_ids_list.append(audio_bos_pos) remain_videos -= 1
elif (
pos_ids, _ = self._compute_interleaved_positions(st_idx, data) min_ed == ed_vision_start
llm_pos_ids_list.append(pos_ids) and ed_vision_start + 1 == ed_audio_start
st_idx = int(pos_ids.max()) + 1 and use_audio_in_video
):
eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1)) text_len = min_ed - st
llm_pos_ids_list.append(eos_pos) if text_len != 0:
llm_pos_ids_list.append(eos_pos) st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
video_len = grid_t * grid_h * grid_w torch.arange(text_len, dtype=torch.long)
audio_len = self._compute_audio_token_count( .view(1, -1)
data["audio_feature_length"] .expand(3, -1)
+ st_idx
) )
st = offset + 2 + video_len + audio_len + 2 st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
bos_block = (
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(bos_block)
llm_pos_ids_list.append(bos_block)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
audio_len = _get_feat_extract_output_lengths(
audio_feature_lengths[audio_idx]
)
audio_llm_pos_ids = (
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* float(second_per_grids[video_idx].item())
* position_id_per_seconds
)
video_llm_pos_ids = get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
video_data_index, audio_data_index = 0, 0
while (
video_data_index < video_llm_pos_ids.shape[-1]
and audio_data_index < audio_llm_pos_ids.shape[-1]
):
if (
video_llm_pos_ids[0][video_data_index]
<= audio_llm_pos_ids[0][audio_data_index]
):
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_data_index + 1
]
)
video_data_index += 1
else:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_data_index + 1
]
)
audio_data_index += 1
if video_data_index < video_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_llm_pos_ids.shape[-1]
]
)
if audio_data_index < audio_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
audio_llm_pos_ids[
:, audio_data_index : audio_llm_pos_ids.shape[-1]
]
)
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
eos_block = (
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(eos_block)
llm_pos_ids_list.append(eos_block)
st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 # noqa: E501
audio_idx += 1
video_idx += 1
remain_videos -= 1
remain_audios -= 1
if st < seq_len: if st < len(input_tokens):
st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0 st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
text_len = seq_len - st text_len = len(input_tokens) - st
llm_pos_ids_list.append( llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
) )
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
if llm_positions.shape[1] != seq_len: if llm_positions.shape[1] != seq_len:
raise RuntimeError("Position ids length mismatch with input ids length") raise RuntimeError("Position ids length mismatch with input ids length")
mrope_position_delta = int(llm_positions.max()) + 1 - seq_len mrope_position_delta = llm_positions.max() + 1 - seq_len
return torch.from_numpy(llm_positions), mrope_position_delta return llm_positions, mrope_position_delta
def get_mm_mapping(self) -> MultiModelKeys: def get_mm_mapping(self) -> MultiModelKeys:
""" """
...@@ -2135,4 +2193,4 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -2135,4 +2193,4 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
language_model="language_model", language_model="language_model",
connector="visual.merger", connector="visual.merger",
tower_model=["visual.", "audio_tower."], tower_model=["visual.", "audio_tower."],
) )
\ No newline at end of file
...@@ -1122,7 +1122,7 @@ class Qwen3LLMModel(Qwen3Model): ...@@ -1122,7 +1122,7 @@ class Qwen3LLMModel(Qwen3Model):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -2004,7 +2004,7 @@ class Qwen3VLForConditionalGeneration( ...@@ -2004,7 +2004,7 @@ class Qwen3VLForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -94,7 +94,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel): ...@@ -94,7 +94,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -476,4 +476,4 @@ class Qwen3VLMoeForConditionalGeneration( ...@@ -476,4 +476,4 @@ class Qwen3VLMoeForConditionalGeneration(
) )
# Set MoE hyperparameters # Set MoE hyperparameters
self.set_moe_parameters() self.set_moe_parameters()
\ No newline at end of file
...@@ -810,7 +810,7 @@ class QwenVLForConditionalGeneration( ...@@ -810,7 +810,7 @@ class QwenVLForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -321,8 +321,8 @@ _MULTIMODAL_MODELS = { ...@@ -321,8 +321,8 @@ _MULTIMODAL_MODELS = {
), ),
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"), "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"), "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"), # noqa: E501
"GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"), # noqa: E501 "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"), # noqa: E501
"GraniteSpeechForConditionalGeneration": ( "GraniteSpeechForConditionalGeneration": (
"granite_speech", "granite_speech",
...@@ -476,7 +476,6 @@ _SPECULATIVE_DECODING_MODELS = { ...@@ -476,7 +476,6 @@ _SPECULATIVE_DECODING_MODELS = {
"LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"), "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"), "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
"GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
"MedusaModel": ("medusa", "Medusa"), "MedusaModel": ("medusa", "Medusa"),
"OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"), "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"), "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
......
...@@ -334,7 +334,7 @@ class SeedOssModel(nn.Module): ...@@ -334,7 +334,7 @@ class SeedOssModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -467,7 +467,7 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -467,7 +467,7 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -489,4 +489,4 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -489,4 +489,4 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
) )
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -898,7 +898,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -898,7 +898,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -944,4 +944,4 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -944,4 +944,4 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
"track_token", "track_token",
] ]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -465,7 +465,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -465,7 +465,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -481,4 +481,4 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -481,4 +481,4 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -246,7 +246,7 @@ class StableLMEpochModel(nn.Module): ...@@ -246,7 +246,7 @@ class StableLMEpochModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -332,7 +332,7 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -332,7 +332,7 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -351,4 +351,4 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -351,4 +351,4 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -252,7 +252,7 @@ class Starcoder2Model(nn.Module): ...@@ -252,7 +252,7 @@ class Starcoder2Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -336,7 +336,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -336,7 +336,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -362,4 +362,4 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -362,4 +362,4 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
["lm_head.weight"] if self.config.tie_word_embeddings else None ["lm_head.weight"] if self.config.tie_word_embeddings else None
), ),
) )
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -354,7 +354,7 @@ class Step3TextModel(nn.Module): ...@@ -354,7 +354,7 @@ class Step3TextModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -419,7 +419,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): ...@@ -419,7 +419,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -551,4 +551,4 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): ...@@ -551,4 +551,4 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
\ No newline at end of file
...@@ -1101,7 +1101,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -1101,7 +1101,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1124,4 +1124,4 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -1124,4 +1124,4 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
\ No newline at end of file
...@@ -714,7 +714,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -714,7 +714,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: torch.Tensor | None = None, intermediate_tensors: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -784,4 +784,4 @@ def pad_and_concat_to_dim3( ...@@ -784,4 +784,4 @@ def pad_and_concat_to_dim3(
# Pad and concatenate: # Pad and concatenate:
# [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)] # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features] features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
return torch.cat(features) return torch.cat(features)
\ No newline at end of file
...@@ -397,7 +397,7 @@ class VoxtralForConditionalGeneration( ...@@ -397,7 +397,7 @@ class VoxtralForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -899,4 +899,4 @@ class VoxtralEncoderModel(nn.Module): ...@@ -899,4 +899,4 @@ class VoxtralEncoderModel(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
return name return name
\ No newline at end of file
...@@ -173,7 +173,7 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration): ...@@ -173,7 +173,7 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -318,4 +318,4 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration): ...@@ -318,4 +318,4 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
audio = (tokenized.audios[0].audio_array, stt_config.sample_rate) audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
prompts_dict = {"multi_modal_data": {"audio": audio}} prompts_dict = {"multi_modal_data": {"audio": audio}}
prompts_dict["prompt_token_ids"] = tokenized.tokens prompts_dict["prompt_token_ids"] = tokenized.tokens
return cast(PromptType, prompts_dict) return cast(PromptType, prompts_dict)
\ No newline at end of file
...@@ -105,7 +105,6 @@ def create_whisper_attention_backend_with_block_pooling( ...@@ -105,7 +105,6 @@ def create_whisper_attention_backend_with_block_pooling(
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
prefix = "WhisperCausalAttentionWithBlockPooling_" prefix = "WhisperCausalAttentionWithBlockPooling_"
underlying_builder = underlying_attn_backend.get_builder_cls() underlying_builder = underlying_attn_backend.get_builder_cls()
underlying_impl = underlying_attn_backend.get_impl_cls()
class WhisperCausalAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore class WhisperCausalAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore
def __init__( def __init__(
...@@ -152,43 +151,6 @@ def create_whisper_attention_backend_with_block_pooling( ...@@ -152,43 +151,6 @@ def create_whisper_attention_backend_with_block_pooling(
common_prefix_len, new_common_attn_metadata, fast_build common_prefix_len, new_common_attn_metadata, fast_build
) )
# NOTE: We need a custom impl so we can use the transformed slot_mapping
# computed by `WhisperCausalAttentionWithBlockPoolingBuilder` instead of
# the one from `forward_context.slot_mapping` (gpu_model_runner).
# This follows the same pattern as CrossAttentionImpl.
class WhisperCausalAttentionWithBlockPoolingImpl(underlying_impl): # type: ignore[valid-type,misc]
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
if (
not underlying_attn_backend.forward_includes_kv_cache_update
and attn_metadata is not None
):
self.do_kv_cache_update(
layer, key, value, kv_cache, attn_metadata.slot_mapping
)
return super().forward(
layer,
query,
key,
value,
kv_cache,
attn_metadata,
output,
output_scale,
output_block_scale,
)
if not issubclass(underlying_attn_backend, FlashAttentionBackend): if not issubclass(underlying_attn_backend, FlashAttentionBackend):
raise NotImplementedError( raise NotImplementedError(
f"{underlying_attn_backend} is not yet supported." f"{underlying_attn_backend} is not yet supported."
...@@ -201,7 +163,6 @@ def create_whisper_attention_backend_with_block_pooling( ...@@ -201,7 +163,6 @@ def create_whisper_attention_backend_with_block_pooling(
attention_backend_cls=underlying_attn_backend, attention_backend_cls=underlying_attn_backend,
overrides={ overrides={
"get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder, "get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder,
"get_impl_cls": lambda: WhisperCausalAttentionWithBlockPoolingImpl,
"get_kv_cache_shape": lambda num_blocks, "get_kv_cache_shape": lambda num_blocks,
block_size, block_size,
num_kv_heads, num_kv_heads,
...@@ -214,7 +175,6 @@ def create_whisper_attention_backend_with_block_pooling( ...@@ -214,7 +175,6 @@ def create_whisper_attention_backend_with_block_pooling(
num_kv_heads // block_pool_size, num_kv_heads // block_pool_size,
head_size, head_size,
), # TODO: generalize to other backends ), # TODO: generalize to other backends
"forward_includes_kv_cache_update": True,
}, },
) )
...@@ -502,4 +462,4 @@ class WhisperCausalEncoder(nn.Module): ...@@ -502,4 +462,4 @@ class WhisperCausalEncoder(nn.Module):
hidden_states = encoder_layer(hidden_states, positions) hidden_states = encoder_layer(hidden_states, positions)
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
return hidden_states return hidden_states
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment