"vscode:/vscode.git/clone" did not exist on "5340b0e2214ec71117e7f0a953cf3033e3194d2a"
Commit a3f8d5dd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
......@@ -747,7 +747,6 @@ class Qwen3NextAttention(nn.Module):
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
rope_parameters=config.rope_parameters,
dual_chunk_attention_config=self.dual_chunk_attention_config,
......@@ -1092,6 +1091,8 @@ class Qwen3NextModel(nn.Module):
name.endswith(".bias") or name.endswith("_bias")
) and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
......@@ -1108,6 +1109,11 @@ class Qwen3NextModel(nn.Module):
continue
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
logger.warning_once(
f"Parameter {name} not found in params_dict, skip loading"
)
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
......
......@@ -48,7 +48,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
......@@ -192,6 +192,7 @@ class Qwen3_VisionBlock(nn.Module):
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
......@@ -205,6 +206,7 @@ class Qwen3_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.mlp = Qwen3_VisionMLP(
......@@ -299,8 +301,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
vision_config,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
......@@ -333,9 +335,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = get_rope(
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
)
self.blocks = nn.ModuleList(
......@@ -347,6 +349,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(vision_config.depth)
......@@ -376,6 +379,12 @@ class Qwen3Omni_VisionTransformer(nn.Module):
]
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
......@@ -1188,17 +1197,12 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
)
self.quant_config = quant_config
......
......@@ -50,7 +50,7 @@ from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
......@@ -67,12 +67,19 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.evs import (
compute_mrope_for_media,
compute_retained_tokens_count,
compute_retention_mask,
recompute_mrope_positions,
)
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalFieldConfig,
MultiModalKwargsItem,
MultiModalKwargsItems,
PlaceholderRange,
VideoItem,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser
......@@ -92,7 +99,9 @@ from .interfaces import (
SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal,
SupportsMultiModalPruning,
SupportsPP,
_require_is_multimodal,
)
from .qwen2_5_vl import (
Qwen2_5_VisionAttention,
......@@ -160,10 +169,15 @@ class Qwen3_VisionMLP(nn.Module):
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.linear_fc1 = ColumnParallelLinear(
in_features,
hidden_features,
......@@ -197,10 +211,9 @@ class Qwen3_VisionBlock(nn.Module):
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
) -> None:
super().__init__()
if norm_layer is None:
......@@ -212,9 +225,8 @@ class Qwen3_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
)
self.mlp = Qwen3_VisionMLP(
dim,
......@@ -222,8 +234,8 @@ class Qwen3_VisionBlock(nn.Module):
act_fn=act_fn,
bias=True,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
def forward(
......@@ -255,10 +267,15 @@ class Qwen3_VisionPatchMerger(nn.Module):
spatial_merge_size: int = 2,
use_postshuffle_norm: bool = False,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2)
self.use_postshuffle_norm = use_postshuffle_norm
......@@ -304,9 +321,8 @@ class Qwen3_VisionTransformer(nn.Module):
vision_config: Qwen3VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
......@@ -317,7 +333,6 @@ class Qwen3_VisionTransformer(nn.Module):
self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.use_data_parallel = use_data_parallel
self.num_grid_per_side = int(self.num_position_embeddings**0.5)
# NOTE: This is used for creating empty tensor for all_gather for
......@@ -339,9 +354,9 @@ class Qwen3_VisionTransformer(nn.Module):
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = get_rope(
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
)
self.merger = Qwen3_VisionPatchMerger(
......@@ -350,8 +365,8 @@ class Qwen3_VisionTransformer(nn.Module):
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)
self.deepstack_merger_list = nn.ModuleList(
......@@ -363,13 +378,16 @@ class Qwen3_VisionTransformer(nn.Module):
use_postshuffle_norm=True,
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
use_data_parallel=use_data_parallel,
)
for layer_idx in range(len(self.deepstack_visual_indexes))
]
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
......@@ -393,9 +411,8 @@ class Qwen3_VisionTransformer(nn.Module):
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
)
for layer_idx in range(vision_config.depth)
]
......@@ -696,17 +713,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
mm_counts: Mapping[str, int],
) -> int:
target_width, target_height = self.get_image_size_with_most_features()
video_soft_tokens = self.get_num_video_tokens(
num_video_soft_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
image_processor=None,
)
# NOTE: By default in Qwen3-VL, one video token is converted to
# "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501
formatted_video_soft_tokens = video_soft_tokens * 12.5
return int(formatted_video_soft_tokens)
return num_video_soft_tokens
def _calculate_timestamps(
self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int
......@@ -1042,13 +1055,39 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False)
for curr_time in timestamps
]
num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
per_frame_token_counts = [tokens_per_frame for _ in frames_idx_token]
video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
if video_pruning_rate is not None and video_pruning_rate > 0.0:
total_retained = compute_retained_tokens_count(
tokens_per_frame,
len(frames_idx_token),
video_pruning_rate,
)
if len(frames_idx_token) == 0:
per_frame_token_counts = []
elif len(frames_idx_token) == 1:
per_frame_token_counts = [tokens_per_frame]
else:
first_frame_tokens = tokens_per_frame
remaining_tokens = max(total_retained - first_frame_tokens, 0)
base = remaining_tokens // (len(frames_idx_token) - 1)
remainder = remaining_tokens % (len(frames_idx_token) - 1)
per_frame_token_counts = [first_frame_tokens]
for frame_idx in range(1, len(frames_idx_token)):
extra = base + (1 if (frame_idx - 1) < remainder else 0)
per_frame_token_counts.append(extra)
placeholder = []
for frame_idx in frames_idx_token:
placeholder.extend(frame_idx)
for frame_idx, timestamp_tokens in enumerate(frames_idx_token):
placeholder.extend(timestamp_tokens)
tokens_this_frame = per_frame_token_counts[
frame_idx if frame_idx < len(per_frame_token_counts) else -1
]
placeholder.extend(
[vision_start_token_id]
+ [video_token_id] * num_tokens_per_frame
+ [video_token_id] * tokens_this_frame
+ [vision_end_token_id]
)
return PromptUpdateDetails.select_token_id(placeholder, video_token_id)
......@@ -1189,6 +1228,7 @@ class Qwen3VLForConditionalGeneration(
SupportsPP,
SupportsMRoPE,
SupportsEagle3,
SupportsMultiModalPruning,
):
packed_modules_mapping = {
"qkv_proj": [
......@@ -1231,23 +1271,22 @@ class Qwen3VLForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate
self.is_multimodal_pruning_enabled = (
multimodal_config.is_multimodal_pruning_enabled()
)
if not multimodal_config.get_limit_per_prompt(
"image"
) and not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.language_model = Qwen3LLMForCausalLM(
......@@ -1417,6 +1456,109 @@ class Qwen3VLForConditionalGeneration(
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
return video_embeds.split(sizes)
def _postprocess_image_embeds_evs(
    self,
    image_embeds_split: tuple[torch.Tensor, ...],
    image_input: Qwen2_5_VLImageInputs,
) -> tuple[torch.Tensor, ...]:
    """
    Attach mrope positions to the embeddings of every image item.

    The positions travel with the embeddings so that correct mrope
    positions can be recovered after video pruning.

    Args:
        image_embeds_split: Tuple of image embeddings, one entry per
            image item.
        image_input: Image input data; its ``image_grid_thw`` entry
            supplies one grid size per image item.

    Returns:
        Tuple of image embeddings, one entry per image item. Each entry
        has the computed mrope positions concatenated along dim 1
        (4 extra channels).
    """
    merge_size = self.visual.spatial_merge_size
    grid_sizes = image_input["image_grid_thw"].tolist()

    def with_positions(embeds: torch.Tensor, grid_size) -> torch.Tensor:
        # Positions are moved to the embedding's device before they are
        # concatenated as additional channels.
        mrope = compute_mrope_for_media(grid_size, merge_size).to(embeds.device)
        return torch.cat([embeds, mrope], dim=1)

    # zip() pairs each embedding with its grid size and stops at the
    # shorter of the two sequences.
    return tuple(
        with_positions(embeds, grid_size)
        for embeds, grid_size in zip(image_embeds_split, grid_sizes)
    )
def _postprocess_video_embeds_evs(
    self,
    video_embeds_split: tuple[torch.Tensor, ...],
    video_input: Qwen2_5_VLVideoInputs,
) -> tuple[torch.Tensor, ...]:
    """
    Prune video embeddings via Efficient Video Sampling (EVS) and append
    the mrope positions of every retained embedding.

    Args:
        video_embeds_split: Tuple of video embeddings for each video item.
        video_input: Video input data. ``video_grid_thw`` must be a 2-D
            tensor with one grid size per video item;
            ``second_per_grid_ts`` is optional.

    Returns:
        Tuple of video embeddings for each video item. Resulting
        embeddings will have extra 4 channels for computed mrope
        positions.
    """
    import logging

    grid_thw = video_input["video_grid_thw"]
    assert grid_thw.ndim == 2
    grid_thw_list = grid_thw.tolist()
    merge_size = self.visual.spatial_merge_size

    second_per_grid_ts = video_input.get("second_per_grid_ts")
    if second_per_grid_ts is None:
        # For Qwen3-VL, second_per_grid_ts might not be available
        # Use default value of 1.0 for each video
        second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long)
    else:
        # Cast to long to match the original code
        # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa
        second_per_grid_ts = second_per_grid_ts.long()

    tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)

    # The debug statistics below call .item(), which forces a host sync for
    # tensors living on an accelerator. Only pay for that when the message
    # would actually be emitted.
    debug_enabled = logger.isEnabledFor(logging.DEBUG)

    video_embeds_out = []
    for emb, size, video_second_per_grid_t in zip(
        video_embeds_split, grid_thw_list, second_per_grid_ts
    ):
        # For each video, we compute retention mask using EVS
        retention_mask = compute_retention_mask(
            emb,
            size,
            spatial_merge_size=merge_size,
            q=self.video_pruning_rate,
        )

        if debug_enabled:
            logger.debug(
                "EVS: Video tokens pruned from %d to %d (T=%d,H=%d,W=%d, "
                "pruning_rate=%.2f, reduction=%.1f%%)",
                emb.shape[0],
                retention_mask.sum().item(),
                size[0],
                size[1],
                size[2],
                self.video_pruning_rate,
                (1 - retention_mask.float().mean().item()) * 100,
            )

        # Positions are computed for the unpruned grid and then filtered
        # with the same mask as the embeddings so both stay aligned.
        positions = compute_mrope_for_media(
            size,
            merge_size,
            tokens_per_second=tokens_per_second,
            video_second_per_grid=video_second_per_grid_t.item(),
        ).to(emb.device)

        emb = emb[retention_mask]
        positions = positions[retention_mask]
        emb = torch.cat([emb, positions], dim=1)
        video_embeds_out.append(emb)
    return tuple(video_embeds_out)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
for input_key in kwargs:
......@@ -1439,6 +1581,20 @@ class Qwen3VLForConditionalGeneration(
def iter_mm_grid_hw(
self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int]]:
"""
Iterate over multimodal features and yield grid information.
For videos with EVS (Efficient Video Sampling) enabled, this function
computes the offset based on the pruned token count rather than relying
on input_tokens.index(), which would fail when tokens are pruned.
Args:
input_tokens: List of token IDs in the prompt
mm_features: List of multimodal feature specifications
Yields:
Tuple of (offset, grid_h, grid_w) for each frame/image
"""
video_token_id = self.config.video_token_id
spatial_merge_size = self.config.vision_config.spatial_merge_size
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
......@@ -1451,42 +1607,289 @@ class Qwen3VLForConditionalGeneration(
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
llm_grid_h = h // spatial_merge_size
llm_grid_w = w // spatial_merge_size
for _ in range(t):
offset = input_tokens.index(video_token_id, offset)
yield offset, llm_grid_h, llm_grid_w
offset += llm_grid_h * llm_grid_w
# Check if EVS (Efficient Video Sampling) is enabled
is_evs_enabled = (
hasattr(self, "video_pruning_rate")
and self.video_pruning_rate is not None
and self.video_pruning_rate > 0.0
)
if is_evs_enabled:
frame_offsets = self._extract_frame_offsets_from_mask(
mm_feature.mm_position, t
)
if frame_offsets is not None:
for rel_offset in frame_offsets:
yield offset + rel_offset, llm_grid_h, llm_grid_w
continue
# If EVS is enabled but mask is missing, this indicates a bug
# in the prompt processing pipeline. The is_embed mask should
# always be present when video_pruning_rate > 0.
raise RuntimeError(
f"EVS is enabled (pruning_rate={self.video_pruning_rate}) "
"but is_embed mask is missing from mm_position. "
"This indicates a bug in prompt processing."
)
else:
# Non-EVS mode: Use original logic with input_tokens.index()
for _ in range(t):
offset = input_tokens.index(video_token_id, offset)
yield offset, llm_grid_h, llm_grid_w
offset += llm_grid_h * llm_grid_w
else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def _get_evs_mask_segments(
self, mm_position: PlaceholderRange, expected_frames: int
) -> list[torch.Tensor] | None:
"""Extract contiguous segments from EVS is_embed mask.
The EVS (Efficient Video Sampling) mask marks which placeholder
positions should be filled with video embeddings. This method splits
the mask into contiguous segments, where each segment represents one
retained frame.
This is a pure function - it does not modify any state and always
returns the same output for the same input (idempotent).
Args:
mm_position: MultiModal position containing the is_embed mask
expected_frames: Expected number of frame segments
Returns:
List of tensors, each containing indices for one frame segment,
or None if EVS is not enabled or validation fails.
"""
is_embed_mask = getattr(mm_position, "is_embed", None)
if is_embed_mask is None:
return None
# Find all True positions in the mask
mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).view(-1)
true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten()
if true_indices.numel() == 0:
return None
# Split into contiguous segments (where diff > 1 indicates a gap)
if true_indices.numel() == 1:
segments = [true_indices]
else:
diffs = torch.diff(true_indices)
split_points = torch.nonzero(diffs != 1, as_tuple=False).flatten()
if split_points.numel() == 0:
segments = [true_indices]
else:
segments = torch.tensor_split(
true_indices, split_points.add(1).tolist()
)
# Validate segment count matches expected frames
if len(segments) < expected_frames:
logger.debug(
"EVS mask segments (%d) do not match expected frames (%d)",
len(segments),
expected_frames,
)
return None
return segments[:expected_frames]
def _extract_frame_offsets_from_mask(
self, mm_position: PlaceholderRange, expected_frames: int
) -> list[int] | None:
"""Return relative offsets for each EVS-retained frame.
The prompt processor stores a boolean mask inside ``mm_position`` that
marks which placeholder locations should be populated with video
embeddings. By splitting that mask into contiguous runs we can recover
the start of every retained frame without probing ``input_tokens``.
Args:
mm_position: MultiModal position containing the is_embed mask
expected_frames: Expected number of frames
Returns:
List of starting offsets (relative to mm_position) for each frame,
or None if EVS is not enabled.
"""
segments = self._get_evs_mask_segments(mm_position, expected_frames)
if segments is None:
return None
return [int(segment[0].item()) for segment in segments]
def _get_actual_frame_token_counts(
self, mm_position: PlaceholderRange, expected_frames: int
) -> list[int] | None:
"""Return actual token count for each EVS-retained frame.
This function calculates the actual number of tokens per frame by
analyzing the is_embed mask, accounting for EVS pruning. Each frame
may have a different token count due to content-aware pruning.
Args:
mm_position: MultiModal position containing the is_embed mask
expected_frames: Expected number of frames
Returns:
List of token counts for each frame, or None if EVS is not enabled.
"""
segments = self._get_evs_mask_segments(mm_position, expected_frames)
if segments is None:
return None
return [len(seg) for seg in segments]
def recompute_mrope_positions(
    self,
    input_ids: list[int],
    multimodal_embeddings: tuple[torch.Tensor, ...],
    mrope_positions: torch.LongTensor,
    num_computed_tokens: int,
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]:
    """
    Refresh the mrope positions from ``num_computed_tokens`` onwards.

    The incoming ``mrope_positions`` were computed for the unpruned
    sequence and become incorrect once media tokens are pruned, so they
    have to be updated before being fed to the LLM.

    Args:
        input_ids: (N,) All input tokens of the prompt (containing the
            entire sequence).
        multimodal_embeddings: Tuple of multimodal embeddings whose last
            4 channels carry the mrope positions.
        mrope_positions: Existing mrope positions (3, N) for the entire
            sequence.
        num_computed_tokens: Number of tokens computed so far.

    Returns:
        Tuple of (multimodal_embeddings, mrope_positions,
        mrope_position_delta), where the returned embeddings no longer
        contain the 4 position channels.
    """
    img_token = self.config.image_token_id
    vid_token = self.config.video_token_id
    vision_start_token = self.config.vision_start_token_id

    # Work on the device of the embeddings when there are any, otherwise
    # on the device of the existing positions.
    if len(multimodal_embeddings):
        device = multimodal_embeddings[0].device
    else:
        device = mrope_positions.device

    input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long)

    # Separate every embedding into its feature channels and its trailing
    # 4 position channels (transposed to channel-first, as integers).
    stripped_embeds = []
    media_positions = []
    for mm in multimodal_embeddings:
        stripped_embeds.append(mm[:, :-4])
        media_positions.append(mm[:, -4:].permute(1, 0).long())

    # NOTE: inside a method body the bare name resolves to the module-level
    # helper of the same name, not to this method.
    new_positions, delta = recompute_mrope_positions(
        input_ids_t,
        media_positions,
        mrope_positions,
        num_computed_tokens,
        vision_start_token,
        img_token,
        vid_token,
    )

    return tuple(stripped_embeds), new_positions, delta
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
# Pre-collect actual frame token counts for EVS mode
frame_token_counts_map = {}
for mm_feature in mm_features:
if mm_feature.modality == "video":
is_evs_enabled = (
hasattr(self, "video_pruning_rate")
and self.video_pruning_rate is not None
and self.video_pruning_rate > 0.0
)
if is_evs_enabled:
t = mm_feature.data["video_grid_thw"].data.tolist()[0]
token_counts = self._get_actual_frame_token_counts(
mm_feature.mm_position, t
)
assert token_counts is not None, (
"EVS enabled but failed to extract frame token counts "
"from is_embed mask"
)
frame_token_counts_map[mm_feature.mm_position.offset] = token_counts
llm_pos_ids_list = []
st = 0
frame_counts_idx = {}
for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
input_tokens, mm_features
):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
# Determine actual token count for this frame
base_offset = None
for feat_offset in frame_token_counts_map:
if offset >= feat_offset:
base_offset = feat_offset
if base_offset is not None:
# EVS mode: use actual token count from is_embed mask
assert base_offset in frame_token_counts_map, (
f"Found base_offset {base_offset} but not in frame_token_counts_map"
)
if base_offset not in frame_counts_idx:
frame_counts_idx[base_offset] = 0
counts = frame_token_counts_map[base_offset]
idx = frame_counts_idx[base_offset]
assert idx < len(counts), (
f"EVS frame index {idx} out of range (total frames: {len(counts)})"
)
actual_frame_tokens = counts[idx]
frame_counts_idx[base_offset] += 1
else:
# Non-EVS mode (or image): use theoretical grid size
actual_frame_tokens = llm_grid_h * llm_grid_w
# Add text segment
text_positions = (
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
llm_pos_ids_list.append(text_positions)
st_idx += text_len
# Add frame segment with actual token count (not theoretical)
grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
llm_pos_ids_list.append(grid_indices + text_len + st_idx)
st = offset + llm_grid_h * llm_grid_w
# Only take the first actual_frame_tokens positions
frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx
llm_pos_ids_list.append(frame_positions)
# Update st using actual token count
st = offset + actual_frame_tokens
# Handle final text segment
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
final_text_positions = (
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
llm_pos_ids_list.append(final_text_positions)
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return torch.from_numpy(llm_positions), mrope_position_delta
def get_language_model(self) -> torch.nn.Module:
......@@ -1507,9 +1910,17 @@ class Qwen3VLForConditionalGeneration(
multimodal_input = mm_input_by_modality[modality]
if modality == "image":
image_embeddings = self._process_image_input(multimodal_input)
if self.is_multimodal_pruning_enabled:
image_embeddings = self._postprocess_image_embeds_evs(
image_embeddings, multimodal_input
)
multimodal_embeddings += tuple(image_embeddings)
if modality == "video":
video_embeddings = self._process_video_input(multimodal_input)
if self.is_multimodal_pruning_enabled:
video_embeddings = self._postprocess_video_embeds_evs(
video_embeddings, multimodal_input
)
multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings
......@@ -1572,12 +1983,7 @@ class Qwen3VLForConditionalGeneration(
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`embed_input_ids` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229."
)
is_multimodal = _require_is_multimodal(is_multimodal)
if self.use_deepstack:
(
......
......@@ -419,6 +419,10 @@ class Qwen3VLMoeForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate
self.is_multimodal_pruning_enabled = (
multimodal_config.is_multimodal_pruning_enabled()
)
if not multimodal_config.get_limit_per_prompt(
"image"
......@@ -429,8 +433,8 @@ class Qwen3VLMoeForConditionalGeneration(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
self.language_model = Qwen3MoeLLMForCausalLM(
......
......@@ -264,10 +264,15 @@ _CROSS_ENCODER_MODELS = {
_MULTIMODAL_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
"AudioFlamingo3ForConditionalGeneration": (
"audioflamingo3",
"AudioFlamingo3ForConditionalGeneration",
),
"AyaVisionForConditionalGeneration": (
"aya_vision",
"AyaVisionForConditionalGeneration",
),
"BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
"BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration": (
......
......@@ -161,7 +161,6 @@ class SeedOssAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
rope_parameters=rope_parameters,
)
......
......@@ -6,14 +6,14 @@ within a vision language model."""
from collections.abc import Iterable
import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F
from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
......@@ -25,11 +25,12 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
from .vision import get_vit_attn_backend
class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
......@@ -147,40 +148,6 @@ class Siglip2VisionEmbeddings(nn.Module):
return patch_embeds
# copy from flash_attn/layers/rotary.py
def rotate_half(x, interleaved=False):
    """Pair up the last-dimension entries of ``x`` and map (a, b) to (-b, a).

    With ``interleaved=False`` the pairs are formed between the first and
    second half of the last dimension; with ``interleaved=True`` they are
    formed between even and odd positions. The output has the same shape
    as ``x``.
    """
    if interleaved:
        evens, odds = x[..., ::2], x[..., 1::2]
        paired = torch.stack((-odds, evens), dim=-1)
        # Fold the trailing pair axis back into the feature axis.
        return rearrange(paired, "... d two -> ... (d two)", two=2)

    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    Apply rotary embedding to the leading ``rotary_dim`` features of ``x``.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Features beyond ``rotary_dim`` are returned unchanged.
    """
    rot_dim = cos.shape[-1] * 2
    assert rot_dim <= x.shape[-1]

    # Expand cos/sin to rot_dim entries and insert a singleton axis before
    # the feature axis; the layout follows the pairing used by rotate_half.
    pattern = "... d -> ... 1 (d 2)" if interleaved else "... d -> ... 1 (2 d)"
    cos_full = repeat(cos, pattern)
    sin_full = repeat(sin, pattern)

    rotated, passthrough = x[..., :rot_dim], x[..., rot_dim:]
    mixed = rotated * cos_full + rotate_half(rotated, interleaved) * sin_full
    return torch.cat([mixed, passthrough], dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
......@@ -190,14 +157,20 @@ def apply_rotary_pos_emb(
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
if is_flash_attn_backend and not current_platform.is_xpu():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
apply_rotary_emb_func = apply_rotary_emb
apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
if is_flash_attn_backend and not current_platform.is_cuda():
apply_rotary_emb_func = apply_rotary_emb.forward_cuda
else:
apply_rotary_emb_func = apply_rotary_emb_torch
q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k)
apply_rotary_emb_func = apply_rotary_emb.forward_native
q_embed = apply_rotary_emb_func(q, cos, sin)
k_embed = apply_rotary_emb_func(k, cos, sin)
return q_embed, k_embed
......@@ -208,6 +181,7 @@ class Siglip2Attention(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
......@@ -227,20 +201,25 @@ class Siglip2Attention(nn.Module):
self.dropout = config.attention_dropout
self.is_causal = False
# TODO(Isotr0py): Enable data parallel after we support
# disabling TP on parallel linear layer
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
disable_tp=use_data_parallel,
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel,
)
self.tp_size = (
......@@ -249,31 +228,13 @@ class Siglip2Attention(nn.Module):
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.use_rope = config.use_rope
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_heads_per_partition,
head_size=self.head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
self.attn_backend = AttentionBackendEnum.TORCH_SDPA
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def forward(
self,
hidden_states: torch.Tensor,
......@@ -298,46 +259,23 @@ class Siglip2Attention(nn.Module):
keys.unsqueeze(0),
cos,
sin,
self.is_flash_attn_backend,
self.attn.is_flash_attn_backend,
)
queries = queries.squeeze(0)
keys = keys.squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
if self.is_flash_attn_backend:
attn_output = self.flash_attn_varlen_func(
queries,
keys,
values,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
).reshape(seq_length, -1)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
batch_size = cu_seqlens.shape[0] - 1
outputs = []
cu = cu_seqlens.tolist()
for i in range(batch_size):
start_idx = cu[i]
end_idx = cu[i + 1]
# Each sequence is processed independently.
q_i = queries[start_idx:end_idx].unsqueeze(0)
k_i = keys[start_idx:end_idx].unsqueeze(0)
v_i = values[start_idx:end_idx].unsqueeze(0)
# (1, seq_len, num_heads, head_dim) ->
# (1, num_heads, seq_len, head_dim)
q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
# (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
output_i = output_i.transpose(1, 2).reshape(end_idx - start_idx, -1)
outputs.append(output_i)
attn_output = torch.cat(outputs, dim=0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
attn_output = self.attn(
query=queries.unsqueeze(0),
key=keys.unsqueeze(0),
value=values.unsqueeze(0),
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
attn_output = attn_output.reshape(
seq_length, self.num_heads_per_partition * self.head_dim
)
attn_output, _ = self.out_proj(attn_output)
return attn_output
......@@ -347,25 +285,30 @@ class Siglip2MLP(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.activation_fn = get_act_fn(config.hidden_act)
# TODO(Isotr0py): Enable data parallel after we support
# disabling TP on parallel linear layer
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
disable_tp=use_data_parallel,
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
disable_tp=use_data_parallel,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......@@ -380,9 +323,8 @@ class Siglip2EncoderLayer(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.embed_dim = config.hidden_size
......@@ -390,16 +332,15 @@ class Siglip2EncoderLayer(nn.Module):
self.self_attn = Siglip2Attention(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
def forward(
......@@ -444,9 +385,8 @@ class Siglip2Encoder(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
......@@ -455,9 +395,8 @@ class Siglip2Encoder(nn.Module):
Siglip2EncoderLayer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{idx}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
for idx in range(config.num_hidden_layers)
]
......@@ -630,9 +569,8 @@ class Siglip2VisionTransformer(nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
......@@ -642,9 +580,8 @@ class Siglip2VisionTransformer(nn.Module):
self.encoder = Siglip2Encoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
......@@ -671,18 +608,16 @@ class Siglip2NavitModel(torch.nn.Module):
self,
config: Siglip2VisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.vision_model = Siglip2VisionTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
def forward(
......
......@@ -160,7 +160,6 @@ class SolarAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
......@@ -148,7 +148,6 @@ class StablelmAttention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.config.max_position_embeddings,
rope_parameters=self.config.rope_parameters,
)
......
......@@ -112,7 +112,6 @@ class Starcoder2Attention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -196,7 +196,6 @@ class Step3TextAttention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embedding,
rope_parameters=rope_parameters,
)
......
......@@ -36,6 +36,8 @@ from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.interfaces import (
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
SupportsQuant,
......@@ -92,7 +94,15 @@ def vllm_flash_attention_forward(
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
class Base(
nn.Module,
VllmModel,
SupportsQuant,
SupportsLoRA,
SupportsPP,
SupportsEagle,
SupportsEagle3,
):
embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
......@@ -131,17 +141,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
self.pp_group = get_pp_group()
self.tp_group = get_tp_group()
# Weights to skip in `self.load_weights`
# Attrs for weight loading (see self.load_weights)
self.skip_prefixes: list[str] = []
"""Skip loading weights whose qualname starts with these prefixes."""
self.skip_substrs: list[str] = []
"""Skip loading weights whose qualname contains these substrings."""
self.ignore_unexpected_prefixes: list[str] = []
"""Ignore unexpected weights whose qualname starts with these prefixes.
"""
"""Ignore unexpected weights whose qualname starts with these prefixes."""
self.ignore_unexpected_suffixes: list[str] = []
"""Ignore unexpected weights whose qualname ends with these suffixes."""
# Attrs for Eagle3 (see self.set_aux_hidden_state_layers)
self._target_class: type[nn.Module] = nn.Module
"""Target class for Eagle3 aux hidden state recording."""
self._layer_names: dict[int, str] = {}
"""Mapping from layer index to layer name for Eagle3."""
self._output_aux_hidden_states_kwargs: dict[str, bool] = {}
"""Kwargs to pass to model forward for Eagle3 aux hidden states."""
if self.quant_config:
quant_method_name = self.quant_config.get_name()
# Check for unsupported quantization methods.
......@@ -278,6 +295,15 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
for child_name, child_module in module.named_children():
new_module = child_module
qual_name = maybe_prefix(prefix, child_name)
# Populate Eagle3 attrs
if (
isinstance(module, nn.ModuleList)
and len(module) == self.text_config.num_hidden_layers
):
self._target_class = type(child_module)
layer_name = qual_name.removeprefix("model.")
self._layer_names[int(child_name)] = layer_name
# Replace modules as needed
if isinstance(child_module, nn.Linear):
generator = (p for p in tp_plan if re.match(p, qual_name))
pattern = next(generator, None)
......@@ -425,19 +451,26 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
else:
position_ids = positions[None, ...]
hidden_states = self.model(
outputs = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
use_cache=False,
position_ids=position_ids,
attention_instances=self.attention_instances,
return_dict=False,
**self._output_aux_hidden_states_kwargs,
**kwargs,
)[0][0, ...] # we remove batch dimension for now
)
# We must remove the batch dimension from these outputs
hidden_states = outputs[0][0, ...]
if self._output_aux_hidden_states_kwargs:
aux_hidden_states = [x[0][0, ...] for x in outputs[1:]]
if not self.pp_group.is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
if self._output_aux_hidden_states_kwargs and len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def load_weights(
......@@ -462,3 +495,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
f"Transformers modeling backend requires transformers>={required} "
f"for {feature}, but got {installed}"
)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
    """Register recorders for the auxiliary hidden states of ``layers``.

    For every entry in ``layers`` an ``OutputRecorder`` is stored on the
    wrapped model under the key ``aux_hidden_state_<layer>``, and the
    matching ``output_aux_hidden_state_<layer>`` flag is set to ``True`` in
    ``self._output_aux_hidden_states_kwargs``.
    """
    self.check_version("5.0.0.dev0", "Eagle3 support")
    from transformers.utils.generic import OutputRecorder

    # The default value in PreTrainedModel is None
    if self.model._can_record_outputs is None:
        self.model._can_record_outputs = {}

    recorded_class = self._target_class
    for layer_idx in layers:
        # layer_idx - 1 because we want the input to the layer
        source_layer_name = self._layer_names[layer_idx - 1]
        record_key = f"aux_hidden_state_{layer_idx}"
        self.model._can_record_outputs[record_key] = OutputRecorder(
            recorded_class, layer_name=source_layer_name
        )
        self._output_aux_hidden_states_kwargs[f"output_{record_key}"] = True
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
    """Return the default Eagle3 auxiliary layer indices.

    The three indices are ``2``, the midpoint ``num_hidden_layers // 2`` and
    ``num_hidden_layers - 3``, read from ``self.text_config``.
    """
    depth = self.text_config.num_hidden_layers
    early, middle, late = 2, depth // 2, depth - 3
    return (early, middle, late)
......@@ -11,7 +11,7 @@ import torch
from transformers import PretrainedConfig
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config import VllmConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
......@@ -88,14 +88,11 @@ def get_vit_attn_backend(
"""
Get the available attention backend for Vision Transformer.
"""
if attn_backend_override is not None:
return attn_backend_override
selected_backend = get_current_vllm_config().attention_config.backend
if selected_backend is not None:
return selected_backend
return current_platform.get_vit_attn_backend(head_size, dtype)
return current_platform.get_vit_attn_backend(
head_size,
dtype,
backend=attn_backend_override,
)
def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool:
......
......@@ -51,7 +51,8 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import MistralTokenizer, cached_tokenizer_from_config
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import init_vllm_registered_model, maybe_prefix
......
......@@ -522,6 +522,7 @@ class WhisperEncoder(nn.Module):
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
hidden_states = []
input_is_batched = False
for features in input_features:
embeds = nn.functional.gelu(self.conv1(features))
embeds = nn.functional.gelu(self.conv2(embeds))
......@@ -530,7 +531,13 @@ class WhisperEncoder(nn.Module):
embeds.dtype
)
hidden_states.append(embeds)
hidden_states = torch.cat(hidden_states)
input_is_batched = embeds.ndim > 2
# Input to MHA must be B x T x D
if input_is_batched:
# Models using WhisperEncoder may handle batching internally.
hidden_states = torch.cat(hidden_states)
else:
hidden_states = torch.stack(hidden_states, dim=0)
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)
......@@ -603,8 +610,7 @@ class WhisperModel(nn.Module):
positions: torch.Tensor,
encoder_outputs: list[torch.Tensor],
) -> torch.Tensor:
assert len(encoder_outputs) in (0, 1)
enc_states = encoder_outputs[0] if len(encoder_outputs) == 1 else None
enc_states = torch.cat(encoder_outputs, dim=0) if len(encoder_outputs) else None
decoder_outputs = self.decoder(
input_ids=input_ids,
positions=positions,
......@@ -913,7 +919,10 @@ class WhisperForConditionalGeneration(
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
# Required as part of SupportsMultiModal interface.
audio_input = self._parse_and_validate_audio_input(**kwargs)
return [self.model.get_encoder_outputs(audio_input["input_features"])]
# Split concatenated encoder outputs into one tensor per audio input
enc_output = self.model.get_encoder_outputs(audio_input["input_features"])
# The assumption is we can only process whole mm items (audios)
return enc_output.unbind(dim=0)
def embed_input_ids(
self,
......
......@@ -230,7 +230,6 @@ class Zamba2Attention(nn.Module):
if config.use_mem_rope:
self.rotary_emb = get_rope(
head_size=self.attention_head_dim,
rotary_dim=self.attention_head_dim,
max_position=config.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -89,7 +89,7 @@ def _extract_data_from_linear_base_module(
assert m.quant_method.quant_config is not None
w = m.weight
ws = m.weight_scale
ws = m.weight_scale_inv if hasattr(m, "weight_scale_inv") else m.weight_scale
quant_block_size = m.quant_method.quant_config.weight_block_size
assert isinstance(w, torch.Tensor)
......
......@@ -127,13 +127,21 @@ class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
def load_bytes(self, data: bytes) -> torch.Tensor:
    """Deserialize a tensor from serialized bytes and return it dense.

    Args:
        data: Bytes readable by ``torch.load`` (loaded with
            ``weights_only=True``).

    Returns:
        The loaded tensor converted with ``to_dense()``.
    """
    buffer = BytesIO(data)
    # Enable sparse tensor integrity checks to prevent out-of-bounds
    # writes from maliciously crafted tensors. The load must happen inside
    # this context; an unchecked early return would bypass the validation.
    with torch.sparse.check_sparse_tensor_invariants():
        tensor = torch.load(buffer, weights_only=True)
        return tensor.to_dense()
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
    """Decode base64 text (strictly validated) and load it via ``load_bytes``."""
    raw = pybase64.b64decode(data, validate=True)
    return self.load_bytes(raw)
def load_file(self, filepath: Path) -> torch.Tensor:
    """Load a tensor from ``filepath`` and return it dense.

    Args:
        filepath: Path to a file readable by ``torch.load`` (loaded with
            ``weights_only=True``).

    Returns:
        The loaded tensor converted with ``to_dense()``.
    """
    # Enable sparse tensor integrity checks to prevent out-of-bounds
    # writes from maliciously crafted tensors. The load must happen inside
    # this context; an unchecked early return would bypass the validation.
    with torch.sparse.check_sparse_tensor_invariants():
        tensor = torch.load(filepath, weights_only=True)
        return tensor.to_dense()
def encode_base64(self, media: torch.Tensor) -> str:
    """Serialize ``media`` to a base64 string using ``tensor2base64``."""
    encoded = tensor2base64(media)
    return encoded
......@@ -122,13 +122,21 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
def load_bytes(self, data: bytes) -> torch.Tensor:
    """Deserialize a tensor from serialized bytes and return it dense.

    Args:
        data: Bytes readable by ``torch.load`` (loaded with
            ``weights_only=True``).

    Returns:
        The loaded tensor converted with ``to_dense()``.
    """
    buffer = BytesIO(data)
    # Enable sparse tensor integrity checks to prevent out-of-bounds
    # writes from maliciously crafted tensors. The load must happen inside
    # this context; an unchecked early return would bypass the validation.
    with torch.sparse.check_sparse_tensor_invariants():
        tensor = torch.load(buffer, weights_only=True)
        return tensor.to_dense()
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
    """Decode base64 text (strictly validated) and load it via ``load_bytes``."""
    raw = pybase64.b64decode(data, validate=True)
    return self.load_bytes(raw)
def load_file(self, filepath: Path) -> torch.Tensor:
    """Load a tensor from ``filepath`` and return it dense.

    Args:
        filepath: Path to a file readable by ``torch.load`` (loaded with
            ``weights_only=True``).

    Returns:
        The loaded tensor converted with ``to_dense()``.
    """
    # Enable sparse tensor integrity checks to prevent out-of-bounds
    # writes from maliciously crafted tensors. The load must happen inside
    # this context; an unchecked early return would bypass the validation.
    with torch.sparse.check_sparse_tensor_invariants():
        tensor = torch.load(filepath, weights_only=True)
        return tensor.to_dense()
def encode_base64(self, media: torch.Tensor) -> str:
    """Base64-encode the NumPy view of ``media`` and return it as UTF-8 text."""
    as_array = media.numpy()
    encoded_bytes = pybase64.b64encode(as_array)
    return encoded_bytes.decode("utf-8")
......@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from functools import partial
from functools import cached_property, partial
from itertools import accumulate
from typing import (
TYPE_CHECKING,
......@@ -169,11 +169,42 @@ class PlaceholderRange:
between `offset` and `offset + length` to assign embeddings to.
"""
def get_num_embeds(self) -> int:
@cached_property
def embeds_cumsum(self) -> torch.Tensor | None:
if self.is_embed is None:
return None
return self.is_embed.cumsum(dim=0)
@cached_property
def get_num_embeds(self) -> int:
if self.embeds_cumsum is None:
return self.length
return int(self.is_embed.sum().item())
return int(self.embeds_cumsum[-1])
def get_embeds_indices_in_range(
    self, start_idx: int, end_idx: int
) -> tuple[int, int]:
    """
    Returns the starting and ending indices of the embeddings of encoder outputs
    in the range of [start_idx, end_idx) in the placeholders.

    For example, given:
    PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])

    If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
    the second and the third embeddings from the encoder output.

    When there is no `is_embed` mask, the indices are returned unchanged.
    An index of 0 (or below) maps to 0 embeddings on both ends, so an empty
    range such as [0, 0) yields (0, 0).
    """
    cumsum = self.embeds_cumsum
    if cumsum is None:
        return start_idx, end_idx

    def embeds_before(idx: int) -> int:
        # cumsum[idx - 1] is the number of embeddings among the first `idx`
        # placeholder positions. Guard idx <= 0 explicitly: cumsum[-1] would
        # wrap around and report the total instead of zero.
        return int(cumsum[idx - 1]) if idx > 0 else 0

    return embeds_before(start_idx), embeds_before(end_idx)
def extract_embeds_range(self) -> list[tuple[int, int]]:
"""Extract the start and end indices of the embedded region in prompt.
......@@ -188,7 +219,7 @@ class PlaceholderRange:
Returns full placeholder range if `is_embed` is `None`.
"""
if self.is_embed is None:
return [(self.offset, self.offset + self.length)]
return [(self.offset, self.offset + self.length - 1)]
mask_i = self.is_embed.int()
starts = torch.nonzero(
......@@ -954,7 +985,7 @@ MultiModalKwargsOptionalItems: TypeAlias = (
)
@deprecated("`MultiModalKwargs` is deprecated and will be removed in v0.13.")
@deprecated("`MultiModalKwargs` is deprecated and will be removed in v0.14.")
class MultiModalKwargs(UserDict[str, NestedTensors]):
"""
A dictionary that represents the keyword arguments to
......@@ -964,7 +995,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
@staticmethod
@deprecated(
"`MultiModalKwargs.from_hf_inputs` is deprecated and "
"will be removed in v0.13. "
"will be removed in v0.14. "
"Please use `MultiModalKwargsItems.from_hf_inputs` and "
"access the tensor data using `.get_data()`."
)
......@@ -977,7 +1008,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
@staticmethod
@deprecated(
"`MultiModalKwargs.from_items` is deprecated and "
"will be removed in v0.13. "
"will be removed in v0.14. "
"Please use `MultiModalKwargsItems.from_seq` and "
"access the tensor data using `.get_data()`."
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment