Commit a3f8d5dd authored by zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
......@@ -9,7 +9,6 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import PretrainedConfig
from transformers.activations import GELUActivation
......@@ -17,11 +16,10 @@ from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.utils import torch_int
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
......@@ -32,6 +30,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
......@@ -61,7 +62,6 @@ from vllm.multimodal.processing import (
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
......@@ -80,7 +80,6 @@ from .utils import (
is_pp_missing_parameter,
maybe_prefix,
)
from .vision import get_vit_attn_backend
logger = init_logger(__name__)
......@@ -344,20 +343,14 @@ def apply_rotary_pos_emb_flashatt(
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
elif current_platform.is_rocm():
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
else:
# For other platforms, use PyTorch fallback
from vllm.model_executor.layers.rotary_embedding.common import (
apply_rotary_emb_torch,
)
apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True)
q_embed = apply_rotary_emb(q, cos, sin)
k_embed = apply_rotary_emb(k, cos, sin)
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
return q_embed, k_embed
......@@ -369,8 +362,8 @@ class KeyeSiglipAttention(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
......@@ -408,34 +401,14 @@ class KeyeSiglipAttention(nn.Module):
prefix=f"{prefix}.out_proj",
)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_heads,
head_size=self.head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
num_kv_heads=self.num_kv_heads,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Keye-VL does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def forward(
self,
hidden_states: torch.Tensor,
......@@ -450,8 +423,7 @@ class KeyeSiglipAttention(nn.Module):
dim=-1,
)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
batch_size = q.shape[0]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
if rope_emb is None:
q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
......@@ -482,38 +454,14 @@ class KeyeSiglipAttention(nn.Module):
self.head_dim,
)
if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = self.flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=False,
softmax_scale=self.scale,
)
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i)
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous()
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
context_layer = rearrange(context_layer, "b s h d -> b s (h d)")
output, _ = self.out_proj(context_layer)
return output
......@@ -547,8 +495,8 @@ class KeyeSiglipEncoderLayer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.embed_dim = config.hidden_size
......@@ -556,8 +504,8 @@ class KeyeSiglipEncoderLayer(nn.Module):
self.self_attn = KeyeSiglipAttention(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
attn_backend_override=attn_backend_override,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
......@@ -601,8 +549,8 @@ class KeyeSiglipEncoder(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
......@@ -614,8 +562,8 @@ class KeyeSiglipEncoder(nn.Module):
KeyeSiglipEncoderLayer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
attn_backend_override=attn_backend_override,
)
for layer_idx in range(config.num_hidden_layers)
]
......@@ -696,8 +644,8 @@ class KeyeSiglipVisionTransformer(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
......@@ -707,8 +655,8 @@ class KeyeSiglipVisionTransformer(nn.Module):
self.encoder = KeyeSiglipEncoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder",
attn_backend_override=attn_backend_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
......@@ -779,16 +727,16 @@ class KeyeSiglipVisionModel(nn.Module):
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.vision_model = KeyeSiglipVisionTransformer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model",
attn_backend_override=attn_backend_override,
)
self.quant_config = quant_config
......@@ -1329,16 +1277,11 @@ class BaseKeyeModule(nn.Module):
self.config = config
self.multimodal_config = multimodal_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = KeyeSiglipVisionModel(
config.vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
)
self.mlp_AR = self._build_projector(
......
......@@ -143,7 +143,6 @@ class Lfm2Attention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -236,7 +236,6 @@ class Lfm2MoeAttention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -259,7 +259,6 @@ class LlamaAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=getattr(config, "rope_parameters", None),
is_neox_style=is_neox_style,
......
......@@ -243,7 +243,6 @@ class Llama4Attention(nn.Module):
self.rotary_emb = (
get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
......
......@@ -277,7 +277,6 @@ class MiniCPMAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
)
......
......@@ -120,7 +120,6 @@ class MiniCPM3Attention(nn.Module):
self.rotary_emb = get_rope(
self.qk_rope_head_dim,
rotary_dim=self.qk_rope_head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
......@@ -199,9 +199,13 @@ class MiniMaxM2Attention(nn.Module):
prefix=f"{prefix}.o_proj",
)
if (
rope_parameters is not None
and "partial_rotary_factor" not in rope_parameters
):
rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
)
......
......@@ -187,7 +187,6 @@ class MiniMaxText01Attention(nn.Module):
num_heads: int,
head_dim: int,
num_kv_heads: int,
rotary_dim: int,
max_position: int = 4096 * 32,
rope_parameters: dict | None = None,
sliding_window: int | None = None,
......@@ -245,7 +244,6 @@ class MiniMaxText01Attention(nn.Module):
)
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position,
rope_parameters=rope_parameters,
is_neox_style=True,
......@@ -290,6 +288,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
head_dim = getattr(config, "head_dim", None)
if head_dim is None:
head_dim = config.hidden_size // config.num_attention_heads
rotary_dim = getattr(config, "rotary_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int):
max_position_embeddings = min(
config.max_position_embeddings, config.max_model_len
......@@ -321,9 +321,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
head_dim=head_dim,
rotary_dim=config.rotary_dim
if hasattr(config, "rotary_dim")
else head_dim,
num_kv_heads=config.num_key_value_heads,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
......
......@@ -18,15 +18,10 @@ from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2DecoderLayer,
DeepseekV2Model,
)
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.mistral_large_3 import MistralLarge3ForCausalLM
from vllm.multimodal.inputs import NestedTensors
from .utils import (
_merge_multimodal_embeddings,
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
from .interfaces import SupportsMultiModal
from .utils import make_empty_intermediate_tensors_factory, maybe_prefix
logger = init_logger(__name__)
......@@ -117,26 +112,10 @@ class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM):
)
super().__init__(vllm_config=vllm_config, prefix=prefix)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = super().embed_input_ids(input_ids)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
def get_language_model(self) -> torch.nn.Module:
return self.model
assert is_multimodal is not None
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
embed_input_ids = SupportsMultiModal.embed_input_ids # type: ignore
def forward(
self,
......@@ -155,11 +134,3 @@ class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM):
"model.embed_tokens.weight",
"lm_head.weight",
}
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: NestedTensors | None = None,
is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
......@@ -206,7 +206,6 @@ class MixtralAttention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -295,11 +295,11 @@ class Llama4VisionAttention(nn.Module):
rope_parameters = {
"rope_type": "mllama4",
"rope_theta": config.rope_parameters["rope_theta"],
"partial_rotary_factor": 0.5,
}
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=config.hidden_size // config.num_attention_heads // 2,
# number of image patches
max_position=(config.image_size // config.patch_size) ** 2,
rope_parameters=rope_parameters,
......
......@@ -105,7 +105,6 @@ class ModernBertAttention(nn.Module):
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
rope_parameters=rope_parameters,
dtype=torch.float16,
......
......@@ -433,7 +433,6 @@ class MolmoAttention(nn.Module):
# Rotary embeddings.
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
......@@ -199,7 +199,6 @@ class NemotronAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
......@@ -118,7 +118,6 @@ class DeciLMAttention(LlamaAttention):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
......
......@@ -102,7 +102,6 @@ class OlmoAttention(nn.Module):
# Rotary embeddings.
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
......@@ -146,7 +146,6 @@ class Olmo2Attention(nn.Module):
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=rope_parameters,
)
......
......@@ -171,7 +171,6 @@ class OlmoeAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -240,18 +240,12 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
)
if multimodal_config.get_limit_per_prompt("image"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = OpenCUAVisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config,
multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
else:
self.visual = None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment