Commit a810671a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori

parents 86b5aefe 6a09612b
...@@ -34,7 +34,7 @@ import torch.nn.functional as F ...@@ -34,7 +34,7 @@ import torch.nn.functional as F
from transformers import BatchFeature from transformers import BatchFeature
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig, VllmConfig from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
...@@ -232,7 +232,7 @@ class HunYuanVisionAttention(nn.Module): ...@@ -232,7 +232,7 @@ class HunYuanVisionAttention(nn.Module):
) )
self.scale = self.hidden_size_per_attention_head**-0.5 self.scale = self.hidden_size_per_attention_head**-0.5
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_attention_heads_per_partition, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
self.scale, self.scale,
......
...@@ -27,7 +27,7 @@ from transformers.models.idefics2.configuration_idefics2 import ( ...@@ -27,7 +27,7 @@ from transformers.models.idefics2.configuration_idefics2 import (
Idefics2VisionConfig, Idefics2VisionConfig,
) )
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
...@@ -161,8 +161,8 @@ class Idefics2VisionAttention(nn.Module): ...@@ -161,8 +161,8 @@ class Idefics2VisionAttention(nn.Module):
prefix=f"{prefix}.out_proj", prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel, disable_tp=use_data_parallel,
) )
# Use unified MultiHeadAttention with Flash Attention support # Use unified MMEncoderAttention with Flash Attention support
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale self.num_heads_per_partition, self.head_dim, self.scale
) )
...@@ -175,7 +175,7 @@ class Idefics2VisionAttention(nn.Module): ...@@ -175,7 +175,7 @@ class Idefics2VisionAttention(nn.Module):
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
query_states, key_states, value_states = qkv.chunk(3, dim=-1) query_states, key_states, value_states = qkv.chunk(3, dim=-1)
# Use unified MultiHeadAttention implementation # Use unified MMEncoderAttention implementation
out = self.attn(query_states, key_states, value_states) out = self.attn(query_states, key_states, value_states)
attn_output, _ = self.out_proj(out) attn_output, _ = self.out_proj(out)
return attn_output return attn_output
......
...@@ -15,7 +15,7 @@ import torch.nn as nn ...@@ -15,7 +15,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import ( from vllm.distributed import (
divide, divide,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
...@@ -207,7 +207,7 @@ class InternParallelAttention(nn.Module): ...@@ -207,7 +207,7 @@ class InternParallelAttention(nn.Module):
disable_tp=use_data_parallel, disable_tp=use_data_parallel,
) )
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale self.num_heads_per_partition, self.head_dim, self.scale
) )
......
...@@ -14,7 +14,7 @@ import torch.nn as nn ...@@ -14,7 +14,7 @@ import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from transformers.utils import torch_int from transformers.utils import torch_int
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -214,8 +214,8 @@ class InternSdpaAttention(nn.Module): ...@@ -214,8 +214,8 @@ class InternSdpaAttention(nn.Module):
self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim) self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
# Use unified MultiHeadAttention with automatic backend selection # Use unified MMEncoderAttention with automatic backend selection
self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale) self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""x shape: (B, N, C)""" """x shape: (B, N, C)"""
...@@ -228,7 +228,7 @@ class InternSdpaAttention(nn.Module): ...@@ -228,7 +228,7 @@ class InternSdpaAttention(nn.Module):
q = self.q_norm(q) q = self.q_norm(q)
k = self.k_norm(k) k = self.k_norm(k)
# Use unified MultiHeadAttention with automatic backend selection # Use unified MMEncoderAttention with automatic backend selection
x = self.attn(q, k, v) x = self.attn(q, k, v)
x = self.projection_layer(x) x = self.projection_layer(x)
......
...@@ -31,7 +31,7 @@ from transformers.models.llama4.image_processing_llama4_fast import ( ...@@ -31,7 +31,7 @@ from transformers.models.llama4.image_processing_llama4_fast import (
get_best_fit, get_best_fit,
) )
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
...@@ -255,7 +255,7 @@ class Llama4VisionAttention(nn.Module): ...@@ -255,7 +255,7 @@ class Llama4VisionAttention(nn.Module):
self.attention_dropout = config.attention_dropout self.attention_dropout = config.attention_dropout
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_local_heads, self.head_dim, self.scaling self.num_local_heads, self.head_dim, self.scaling
) )
......
...@@ -17,7 +17,8 @@ from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorT ...@@ -17,7 +17,8 @@ from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorT
from transformers.image_utils import ImageInput from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.attention.layer import Attention, MultiHeadAttention from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
...@@ -222,7 +223,7 @@ class MultiHeadDotProductAttention(nn.Module): ...@@ -222,7 +223,7 @@ class MultiHeadDotProductAttention(nn.Module):
) )
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads
) )
......
...@@ -169,10 +169,13 @@ class DeciLMDecoderLayer(nn.Module): ...@@ -169,10 +169,13 @@ class DeciLMDecoderLayer(nn.Module):
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if not self._is_no_op_ffn: if not self._is_no_op_ffn:
ffn_mult = block_config.ffn.ffn_mult if hasattr(block_config.ffn, "ffn_mult"):
intermediate_size = _ffn_mult_to_intermediate_size( ffn_mult = block_config.ffn.ffn_mult
ffn_mult, config.hidden_size intermediate_size = _ffn_mult_to_intermediate_size(
) ffn_mult, config.hidden_size
)
else:
intermediate_size = block_config.ffn.intermediate_size
self.mlp = LlamaMLP( self.mlp = LlamaMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
......
...@@ -70,7 +70,6 @@ from vllm.multimodal.inputs import ( ...@@ -70,7 +70,6 @@ from vllm.multimodal.inputs import (
MultiModalFeatureSpec, MultiModalFeatureSpec,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalKwargsItems,
NestedTensors,
) )
from vllm.multimodal.parse import ( from vllm.multimodal.parse import (
AudioProcessorItems, AudioProcessorItems,
...@@ -1150,27 +1149,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -1150,27 +1149,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
handle_oov_mm_token=handle_oov_mm_token, handle_oov_mm_token=handle_oov_mm_token,
) )
def embed_multimodal_v0(self, **kwargs: object) -> NestedTensors | None:
audio_input = self._parse_and_validate_audio_input(**kwargs)
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
if audio_input is None and image_input is None and video_input is None:
return None
multimodal_embeddings: list[tuple[NestedTensors, str]] = []
if audio_input is not None:
audio_embeds = self._process_audio_input(audio_input)
multimodal_embeddings.append((audio_embeds, "audio"))
if image_input is not None:
image_embeds = self._process_image_input(image_input)
multimodal_embeddings.append((image_embeds, "image"))
if video_input is not None:
video_embeds = self._process_video_input(video_input)
multimodal_embeddings.append((video_embeds, "video"))
return multimodal_embeddings
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -57,12 +57,6 @@ class RobertaEmbedding(nn.Module): ...@@ -57,12 +57,6 @@ class RobertaEmbedding(nn.Module):
torch.arange(config.max_position_embeddings).unsqueeze(0), torch.arange(config.max_position_embeddings).unsqueeze(0),
) )
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
raise ValueError(
"Only 'absolute' position_embedding_type" + " is supported"
)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -135,12 +129,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel): ...@@ -135,12 +129,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
def _build_model( def _build_model(
self, vllm_config: VllmConfig, prefix: str = "" self, vllm_config: VllmConfig, prefix: str = ""
) -> BertModel | BertWithRope: ) -> BertModel | BertWithRope:
if vllm_config.model_config.hf_config.position_embedding_type == "rotary": hf_config = vllm_config.model_config.hf_config
return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix) kwargs = dict(vllm_config=vllm_config, prefix=prefix)
if getattr(hf_config, "position_embedding_type", "absolute") == "absolute":
return BertModel(**kwargs, embedding_class=RobertaEmbedding)
else: else:
return BertModel( return JinaRobertaModel(**kwargs)
vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights_list = list(weights) weights_list = list(weights)
......
...@@ -16,8 +16,8 @@ from transformers import ( ...@@ -16,8 +16,8 @@ from transformers import (
SiglipVisionConfig, SiglipVisionConfig,
) )
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
...@@ -379,7 +379,7 @@ class SiglipAttention(nn.Module): ...@@ -379,7 +379,7 @@ class SiglipAttention(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -481,7 +481,7 @@ class SiglipEncoderLayer(nn.Module): ...@@ -481,7 +481,7 @@ class SiglipEncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -527,7 +527,7 @@ class SiglipEncoder(nn.Module): ...@@ -527,7 +527,7 @@ class SiglipEncoder(nn.Module):
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -700,7 +700,7 @@ class SiglipVisionTransformer(nn.Module): ...@@ -700,7 +700,7 @@ class SiglipVisionTransformer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
attn_cls=MultiHeadAttention, attn_cls=MMEncoderAttention,
) )
num_hidden_layers = config.num_hidden_layers num_hidden_layers = config.num_hidden_layers
......
...@@ -15,7 +15,7 @@ from torchvision import transforms ...@@ -15,7 +15,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from transformers import BatchFeature, PretrainedConfig, TensorType from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
...@@ -753,8 +753,8 @@ class Step3VisionAttention(nn.Module): ...@@ -753,8 +753,8 @@ class Step3VisionAttention(nn.Module):
disable_tp=use_data_parallel, disable_tp=use_data_parallel,
) )
# Use unified MultiHeadAttention with automatic backend selection # Use unified MMEncoderAttention with automatic backend selection
self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale) self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
def forward( def forward(
self, self,
...@@ -767,7 +767,7 @@ class Step3VisionAttention(nn.Module): ...@@ -767,7 +767,7 @@ class Step3VisionAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
# Use unified MultiHeadAttention with automatic backend selection # Use unified MMEncoderAttention with automatic backend selection
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
attn_output, _ = self.out_proj(attn_output) attn_output, _ = self.out_proj(attn_output)
......
...@@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module): ...@@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None, attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False, output_attentions: bool | None = False,
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
batch_size, dim, num_channels = hidden_states.shape batch_size, dim, num_channels = hidden_states.shape
...@@ -201,12 +200,9 @@ class SwinAttention(nn.Module): ...@@ -201,12 +200,9 @@ class SwinAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None, attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False, output_attentions: bool | None = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
self_outputs = self.self( self_outputs = self.self(hidden_states, attention_mask, output_attentions)
hidden_states, attention_mask, head_mask, output_attentions
)
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] outputs = (attention_output,) + self_outputs[1:]
return outputs return outputs
...@@ -339,18 +335,14 @@ class SwinStage(nn.Module): ...@@ -339,18 +335,14 @@ class SwinStage(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_dimensions: tuple[int, int], input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False, output_attentions: bool | None = False,
always_partition: bool | None = False, always_partition: bool | None = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
height, width = input_dimensions height, width = input_dimensions
for i, layer_module in enumerate(self.blocks): for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states,
input_dimensions, input_dimensions,
layer_head_mask,
output_attentions, output_attentions,
always_partition, always_partition,
) )
...@@ -425,17 +417,13 @@ class SwinEncoder(nn.Module): ...@@ -425,17 +417,13 @@ class SwinEncoder(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_dimensions: tuple[int, int], input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False, output_attentions: bool | None = False,
always_partition: bool | None = False, always_partition: bool | None = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
for i, layer_module in enumerate(self.layers): for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states,
input_dimensions, input_dimensions,
layer_head_mask,
output_attentions, output_attentions,
always_partition, always_partition,
) )
...@@ -473,7 +461,6 @@ class SwinModel(nn.Module): ...@@ -473,7 +461,6 @@ class SwinModel(nn.Module):
def forward( def forward(
self, self,
pixel_values: torch.FloatTensor | None = None, pixel_values: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = None, output_attentions: bool | None = None,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
embedding_output, input_dimensions = self.embeddings(pixel_values) embedding_output, input_dimensions = self.embeddings(pixel_values)
...@@ -481,7 +468,6 @@ class SwinModel(nn.Module): ...@@ -481,7 +468,6 @@ class SwinModel(nn.Module):
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
input_dimensions, input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
import copy import copy
import inspect
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from types import SimpleNamespace from types import SimpleNamespace
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
...@@ -380,11 +381,17 @@ class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin): ...@@ -380,11 +381,17 @@ class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
) )
hidden_states = hidden_states + positions hidden_states = hidden_states + positions
# Backward compatibility for Transformers v4 where layer_head_mask
# was a required argument for WhisperEncoderLayer.forward
kwargs = {}
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
kwargs["layer_head_mask"] = None
for layer in self.layers: for layer in self.layers:
layer_outputs = layer( layer_outputs = layer(
hidden_states, hidden_states,
attention_mask=extended_attention_mask, attention_mask=extended_attention_mask,
layer_head_mask=None, **kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
...@@ -479,11 +486,17 @@ class ModifiedWhisperEncoder(WhisperEncoder): ...@@ -479,11 +486,17 @@ class ModifiedWhisperEncoder(WhisperEncoder):
attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states) attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)
# Backward compatibility for Transformers v4 where layer_head_mask
# was a required argument for WhisperEncoderLayer.forward
kwargs = {}
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
kwargs["layer_head_mask"] = None
for encoder_layer in self.layers: for encoder_layer in self.layers:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask=None, **kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
......
...@@ -16,9 +16,9 @@ from transformers import ( ...@@ -16,9 +16,9 @@ from transformers import (
) )
from transformers.models.whisper.modeling_whisper import sinusoids from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention, AttentionType
from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.attention.layers.cross_attention import CrossAttention from vllm.attention.layers.cross_attention import CrossAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
...@@ -141,7 +141,7 @@ class WhisperAudioInputs(TensorSchema): ...@@ -141,7 +141,7 @@ class WhisperAudioInputs(TensorSchema):
] ]
class WhisperEncoderAttention(MultiHeadAttention): class WhisperEncoderAttention(MMEncoderAttention):
"""Multi-headed attention for Whisper encoder with 2D tensor support.""" """Multi-headed attention for Whisper encoder with 2D tensor support."""
def forward( def forward(
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
from tqdm import tqdm from tqdm import tqdm
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.parallel_state import get_dp_group from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
...@@ -175,7 +175,30 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: ...@@ -175,7 +175,30 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set() FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens: int): def _get_fp8_gemm_nt_m_values(w: torch.Tensor, max_tokens: int) -> list[int]:
"""Get the M values to warmup for a given weight tensor."""
n, _ = w.size()
device = w.device
# Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
# Otherwise warmup all token sizes to avoid JIT compilation in hotpath
if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
return _generate_optimal_warmup_m_values(max_tokens, n, device)
else:
assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
"Expected "
'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
f"{envs.VLLM_DEEP_GEMM_WARMUP}"
)
return list(range(1, max_tokens + 1))
def _deepgemm_fp8_gemm_nt_warmup(
w: torch.Tensor,
ws: torch.Tensor,
max_tokens: int,
pbar: tqdm | None = None,
):
if w.size() in FP8_GEMM_NT_WARMUP_CACHE: if w.size() in FP8_GEMM_NT_WARMUP_CACHE:
return return
...@@ -189,27 +212,14 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens: ...@@ -189,27 +212,14 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
) )
out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16) out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)
# Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax". m_values = _get_fp8_gemm_nt_m_values(w, max_tokens)
# Otherwise warmup all token sizes to avoid JIT compilation in hotpath
if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
m_values = _generate_optimal_warmup_m_values(max_tokens, n, device)
desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]"
else:
assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
"Expected "
'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
f"{envs.VLLM_DEEP_GEMM_WARMUP}"
)
m_values = list(range(1, max_tokens + 1))
desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]"
pbar = tqdm(total=len(m_values), desc=desc)
for num_tokens in m_values: for num_tokens in m_values:
fp8_gemm_nt( fp8_gemm_nt(
(a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens] (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens]
) )
pbar.update(1) if pbar is not None:
pbar.update(1)
FP8_GEMM_NT_WARMUP_CACHE.add(w.size()) FP8_GEMM_NT_WARMUP_CACHE.add(w.size())
...@@ -217,20 +227,12 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens: ...@@ -217,20 +227,12 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set() GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()
def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( def _get_grouped_gemm_params(
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
num_topk: int, num_topk: int,
max_tokens: int, max_tokens: int,
): ) -> tuple[int, int, torch.Tensor]:
if (
w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
):
return
assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts" assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"
block_m = get_mk_alignment_for_contiguous_layout()[0] block_m = get_mk_alignment_for_contiguous_layout()[0]
...@@ -253,6 +255,27 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( ...@@ -253,6 +255,27 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
) )
expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0) expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
return MAX_M, block_m, expert_ids
def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
num_topk: int,
max_tokens: int,
pbar: tqdm | None = None,
):
if (
w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
):
return
MAX_M, block_m, expert_ids = _get_grouped_gemm_params(w1, w2, num_topk, max_tokens)
device = w1.device
def _warmup(w: torch.Tensor, w_scale: torch.Tensor): def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
_, n, k = w.size() _, n, k = w.size()
a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn) a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn)
...@@ -261,15 +284,8 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( ...@@ -261,15 +284,8 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
) )
out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16) out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
# Generate M values in block_m increments (already optimized for MoE)
m_values = list(range(block_m, MAX_M + 1, block_m)) m_values = list(range(block_m, MAX_M + 1, block_m))
pbar = tqdm(
total=len(m_values),
desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) "
f"[{len(m_values)} values, block_m={block_m}]",
)
for num_tokens in m_values: for num_tokens in m_values:
m_grouped_fp8_gemm_nt_contiguous( m_grouped_fp8_gemm_nt_contiguous(
(a1q[:num_tokens], a1q_scales[:num_tokens]), (a1q[:num_tokens], a1q_scales[:num_tokens]),
...@@ -277,7 +293,8 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( ...@@ -277,7 +293,8 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
out[:num_tokens], out[:num_tokens],
expert_ids[:num_tokens], expert_ids[:num_tokens],
) )
pbar.update(1) if pbar is not None:
pbar.update(1)
for w, ws in [(w1, w1_scale), (w2, w2_scale)]: for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
...@@ -285,16 +302,18 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( ...@@ -285,16 +302,18 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size()) GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size())
def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int): def deepgemm_fp8_gemm_nt_warmup(
model: torch.nn.Module, max_tokens: int, pbar: tqdm | None = None
):
dg_modules = [m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m)] dg_modules = [m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m)]
for dgm in dg_modules: for dgm in dg_modules:
w, ws, _ = _extract_data_from_linear_base_module(dgm) w, ws, _ = _extract_data_from_linear_base_module(dgm)
_deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens) _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens, pbar=pbar)
def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
model: torch.nn.Module, max_tokens: int model: torch.nn.Module, max_tokens: int, pbar: tqdm | None = None
): ):
dg_modules = [ dg_modules = [
m for m in model.modules() if _fused_moe_grouped_gemm_may_use_deep_gemm(m) m for m in model.modules() if _fused_moe_grouped_gemm_may_use_deep_gemm(m)
...@@ -305,10 +324,48 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( ...@@ -305,10 +324,48 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
dgm dgm
) )
_deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
w13, w2, w13_scale, w2_scale, num_topk, max_tokens w13, w2, w13_scale, w2_scale, num_topk, max_tokens, pbar=pbar
) )
def _count_warmup_iterations(model: torch.nn.Module, max_tokens: int) -> int:
seen_fp8_sizes: set[torch.Size] = set(FP8_GEMM_NT_WARMUP_CACHE)
seen_grouped_sizes: set[torch.Size] = set(
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
)
total = 0
for m in model.modules():
if _fp8_linear_may_use_deep_gemm(m):
w, _, _ = _extract_data_from_linear_base_module(m)
if w.size() not in seen_fp8_sizes:
total += len(_get_fp8_gemm_nt_m_values(w, max_tokens))
seen_fp8_sizes.add(w.size())
elif _fused_moe_grouped_gemm_may_use_deep_gemm(m):
w13, _, w2, _, num_topk = _extract_data_from_fused_moe_module(m)
if w13.size() in seen_grouped_sizes and w2.size() in seen_grouped_sizes:
continue
MAX_M, block_m, _ = _get_grouped_gemm_params(w13, w2, num_topk, max_tokens)
n_values = (MAX_M - block_m) // block_m + 1
if w13.size() not in seen_grouped_sizes:
total += n_values
seen_grouped_sizes.add(w13.size())
if w2.size() not in seen_grouped_sizes:
total += n_values
seen_grouped_sizes.add(w2.size())
return total
def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int): def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
deepgemm_fp8_gemm_nt_warmup(model, max_tokens) total = _count_warmup_iterations(model, max_tokens)
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens) if total == 0:
return
# Only show progress bar on rank 0 to avoid cluttered output
if is_global_first_rank():
with tqdm(total=total, desc="DeepGEMM warmup") as pbar:
deepgemm_fp8_gemm_nt_warmup(model, max_tokens, pbar)
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens, pbar)
else:
deepgemm_fp8_gemm_nt_warmup(model, max_tokens, None)
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens, None)
...@@ -49,13 +49,12 @@ def kernel_warmup(worker: "Worker"): ...@@ -49,13 +49,12 @@ def kernel_warmup(worker: "Worker"):
except NotImplementedError: except NotImplementedError:
return False return False
# NOTE: we add check for empty attn_groups to avoid errors when
# deploying models such as E instances and encoder-only models.
# As for those models, worker.model_runner.attn_groups is empty.
# This change is made during EPD feature development.
if ( if (
not worker.model_runner.is_pooling_model not worker.model_runner.is_pooling_model
and worker.model_runner.attn_groups and worker.model_runner.attn_groups
# NOTE: This should be `any` instead of `all` but other hybrid attention
# backends don't support this dummy run. Once we remove
# `build_for_cudagraph_capture`, we can change it to `any`.
and all( and all(
_is_flashinfer_backend(group.backend) _is_flashinfer_backend(group.backend)
for groups in worker.model_runner.attn_groups for groups in worker.model_runner.attn_groups
......
...@@ -124,11 +124,9 @@ def use_rocm_custom_paged_attention( ...@@ -124,11 +124,9 @@ def use_rocm_custom_paged_attention(
alibi_slopes: torch.Tensor | None = None, alibi_slopes: torch.Tensor | None = None,
sinks: torch.Tensor | None = None, sinks: torch.Tensor | None = None,
) -> bool: ) -> bool:
from vllm._aiter_ops import rocm_aiter_ops GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
# GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
# ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
# ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
# custom paged attn always supported on V0. On V1, requires sliding window # custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy. # disabled due to observed numerical discrepancy.
...@@ -141,7 +139,6 @@ def use_rocm_custom_paged_attention( ...@@ -141,7 +139,6 @@ def use_rocm_custom_paged_attention(
# and (gqa_ratio >= 1 and gqa_ratio <= 16) # and (gqa_ratio >= 1 and gqa_ratio <= 16)
# and max_seq_len <= 128 * 1024 # and max_seq_len <= 128 * 1024
# and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) # and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
# and not (rocm_aiter_ops.is_pa_attn_enabled())
# and sinks is None # and sinks is None
# ) # )
......
...@@ -162,7 +162,10 @@ class XPUPlatform(Platform): ...@@ -162,7 +162,10 @@ class XPUPlatform(Platform):
# check and update parallel config # check and update parallel config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker" # Only override worker_cls if it's still the default "auto"
# This allows custom workers (like vllm-omni workers) to be used on XPU
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
if vllm_config.kv_transfer_config is not None: if vllm_config.kv_transfer_config is not None:
vllm_config.kv_transfer_config.enable_permute_local_kv = True vllm_config.kv_transfer_config.enable_permute_local_kv = True
......
...@@ -26,6 +26,8 @@ class DeepSeekV3ReasoningParser(ReasoningParser): ...@@ -26,6 +26,8 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
chat_kwargs = kwargs.pop("chat_template_kwargs", {}) or {} chat_kwargs = kwargs.pop("chat_template_kwargs", {}) or {}
thinking = bool(chat_kwargs.pop("thinking", False)) thinking = bool(chat_kwargs.pop("thinking", False))
enable_thinking = bool(chat_kwargs.pop("enable_thinking", False))
thinking = thinking or enable_thinking
if thinking: if thinking:
self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs) self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs)
......
...@@ -50,6 +50,8 @@ class DeepseekV32Tokenizer(CachedHfTokenizer): ...@@ -50,6 +50,8 @@ class DeepseekV32Tokenizer(CachedHfTokenizer):
**kwargs, **kwargs,
) -> str | list[int]: ) -> str | list[int]:
thinking = kwargs.get("thinking", False) thinking = kwargs.get("thinking", False)
enable_thinking = kwargs.get("enable_thinking", False)
thinking = thinking or enable_thinking
thinking_mode = "thinking" thinking_mode = "thinking"
if not thinking: if not thinking:
thinking_mode = "chat" thinking_mode = "chat"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment