Commit a3f8d5dd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
...@@ -9,7 +9,6 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar ...@@ -9,7 +9,6 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import PretrainedConfig from transformers import PretrainedConfig
from transformers.activations import GELUActivation from transformers.activations import GELUActivation
...@@ -17,11 +16,10 @@ from transformers.feature_extraction_utils import BatchFeature ...@@ -17,11 +16,10 @@ from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.utils import torch_int from transformers.utils import torch_int
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layers.mm_encoder_attention import (
from vllm.attention.layer import ( MMEncoderAttention,
maybe_get_vit_flash_attn_backend,
) )
from vllm.config import VllmConfig from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -32,6 +30,9 @@ from vllm.model_executor.layers.linear import ( ...@@ -32,6 +30,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
...@@ -61,7 +62,6 @@ from vllm.multimodal.processing import ( ...@@ -61,7 +62,6 @@ from vllm.multimodal.processing import (
PromptUpdate, PromptUpdate,
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -80,7 +80,6 @@ from .utils import ( ...@@ -80,7 +80,6 @@ from .utils import (
is_pp_missing_parameter, is_pp_missing_parameter,
maybe_prefix, maybe_prefix,
) )
from .vision import get_vit_attn_backend
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -344,20 +343,14 @@ def apply_rotary_pos_emb_flashatt( ...@@ -344,20 +343,14 @@ def apply_rotary_pos_emb_flashatt(
cos = cos.chunk(2, dim=-1)[0].contiguous() cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous()
if current_platform.is_cuda(): apply_rotary_emb = ApplyRotaryEmb(
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb enforce_enable=True,
elif current_platform.is_rocm(): enable_fp32_compute=True,
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb )
else:
# For other platforms, use PyTorch fallback
from vllm.model_executor.layers.rotary_embedding.common import (
apply_rotary_emb_torch,
)
apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True) q_embed = apply_rotary_emb(q, cos, sin)
k_embed = apply_rotary_emb(k, cos, sin)
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
return q_embed, k_embed return q_embed, k_embed
...@@ -369,8 +362,8 @@ class KeyeSiglipAttention(nn.Module): ...@@ -369,8 +362,8 @@ class KeyeSiglipAttention(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -408,34 +401,14 @@ class KeyeSiglipAttention(nn.Module): ...@@ -408,34 +401,14 @@ class KeyeSiglipAttention(nn.Module):
prefix=f"{prefix}.out_proj", prefix=f"{prefix}.out_proj",
) )
# Detect attention implementation. self.attn = MMEncoderAttention(
self.attn_backend = get_vit_attn_backend( num_heads=self.num_heads,
head_size=self.head_dim, head_size=self.head_dim,
dtype=torch.get_default_dtype(), num_kv_heads=self.num_kv_heads,
attn_backend_override=attn_backend_override, prefix=f"{prefix}.attn",
) multimodal_config=multimodal_config,
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
) )
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Keye-VL does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -450,8 +423,7 @@ class KeyeSiglipAttention(nn.Module): ...@@ -450,8 +423,7 @@ class KeyeSiglipAttention(nn.Module):
dim=-1, dim=-1,
) )
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
batch_size = q.shape[0]
if rope_emb is None: if rope_emb is None:
q = q.view(*q.shape[:-1], self.num_heads, self.head_dim) q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
...@@ -482,38 +454,14 @@ class KeyeSiglipAttention(nn.Module): ...@@ -482,38 +454,14 @@ class KeyeSiglipAttention(nn.Module):
self.head_dim, self.head_dim,
) )
if self.is_flash_attn_backend: context_layer = self.attn(
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) query=q,
key=k,
output = self.flash_attn_varlen_func( value=v,
q, cu_seqlens=cu_seqlens,
k, max_seqlen=max_seqlen,
v, )
cu_seqlens_q=cu_seqlens, context_layer = rearrange(context_layer, "b s h d -> b s (h d)")
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=False,
softmax_scale=self.scale,
)
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i)
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous()
output, _ = self.out_proj(context_layer) output, _ = self.out_proj(context_layer)
return output return output
...@@ -547,8 +495,8 @@ class KeyeSiglipEncoderLayer(nn.Module): ...@@ -547,8 +495,8 @@ class KeyeSiglipEncoderLayer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -556,8 +504,8 @@ class KeyeSiglipEncoderLayer(nn.Module): ...@@ -556,8 +504,8 @@ class KeyeSiglipEncoderLayer(nn.Module):
self.self_attn = KeyeSiglipAttention( self.self_attn = KeyeSiglipAttention(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
attn_backend_override=attn_backend_override,
) )
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
...@@ -601,8 +549,8 @@ class KeyeSiglipEncoder(nn.Module): ...@@ -601,8 +549,8 @@ class KeyeSiglipEncoder(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -614,8 +562,8 @@ class KeyeSiglipEncoder(nn.Module): ...@@ -614,8 +562,8 @@ class KeyeSiglipEncoder(nn.Module):
KeyeSiglipEncoderLayer( KeyeSiglipEncoderLayer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
attn_backend_override=attn_backend_override,
) )
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
] ]
...@@ -696,8 +644,8 @@ class KeyeSiglipVisionTransformer(nn.Module): ...@@ -696,8 +644,8 @@ class KeyeSiglipVisionTransformer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -707,8 +655,8 @@ class KeyeSiglipVisionTransformer(nn.Module): ...@@ -707,8 +655,8 @@ class KeyeSiglipVisionTransformer(nn.Module):
self.encoder = KeyeSiglipEncoder( self.encoder = KeyeSiglipEncoder(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
attn_backend_override=attn_backend_override,
) )
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
...@@ -779,16 +727,16 @@ class KeyeSiglipVisionModel(nn.Module): ...@@ -779,16 +727,16 @@ class KeyeSiglipVisionModel(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.vision_model = KeyeSiglipVisionTransformer( self.vision_model = KeyeSiglipVisionTransformer(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.vision_model", prefix=f"{prefix}.vision_model",
attn_backend_override=attn_backend_override,
) )
self.quant_config = quant_config self.quant_config = quant_config
...@@ -1329,16 +1277,11 @@ class BaseKeyeModule(nn.Module): ...@@ -1329,16 +1277,11 @@ class BaseKeyeModule(nn.Module):
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = KeyeSiglipVisionModel( self.visual = KeyeSiglipVisionModel(
config.vision_config, config.vision_config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
attn_backend_override=attn_backend_override,
) )
self.mlp_AR = self._build_projector( self.mlp_AR = self._build_projector(
......
...@@ -143,7 +143,6 @@ class Lfm2Attention(nn.Module): ...@@ -143,7 +143,6 @@ class Lfm2Attention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
......
...@@ -236,7 +236,6 @@ class Lfm2MoeAttention(nn.Module): ...@@ -236,7 +236,6 @@ class Lfm2MoeAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
......
...@@ -259,7 +259,6 @@ class LlamaAttention(nn.Module): ...@@ -259,7 +259,6 @@ class LlamaAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=getattr(config, "rope_parameters", None), rope_parameters=getattr(config, "rope_parameters", None),
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
......
...@@ -243,7 +243,6 @@ class Llama4Attention(nn.Module): ...@@ -243,7 +243,6 @@ class Llama4Attention(nn.Module):
self.rotary_emb = ( self.rotary_emb = (
get_rope( get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
......
...@@ -277,7 +277,6 @@ class MiniCPMAttention(nn.Module): ...@@ -277,7 +277,6 @@ class MiniCPMAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )
......
...@@ -120,7 +120,6 @@ class MiniCPM3Attention(nn.Module): ...@@ -120,7 +120,6 @@ class MiniCPM3Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.qk_rope_head_dim, self.qk_rope_head_dim,
rotary_dim=self.qk_rope_head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
......
...@@ -199,9 +199,13 @@ class MiniMaxM2Attention(nn.Module): ...@@ -199,9 +199,13 @@ class MiniMaxM2Attention(nn.Module):
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
) )
if (
rope_parameters is not None
and "partial_rotary_factor" not in rope_parameters
):
rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )
......
...@@ -187,7 +187,6 @@ class MiniMaxText01Attention(nn.Module): ...@@ -187,7 +187,6 @@ class MiniMaxText01Attention(nn.Module):
num_heads: int, num_heads: int,
head_dim: int, head_dim: int,
num_kv_heads: int, num_kv_heads: int,
rotary_dim: int,
max_position: int = 4096 * 32, max_position: int = 4096 * 32,
rope_parameters: dict | None = None, rope_parameters: dict | None = None,
sliding_window: int | None = None, sliding_window: int | None = None,
...@@ -245,7 +244,6 @@ class MiniMaxText01Attention(nn.Module): ...@@ -245,7 +244,6 @@ class MiniMaxText01Attention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
head_size=self.head_dim, head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position, max_position=max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=True, is_neox_style=True,
...@@ -290,6 +288,8 @@ class MiniMaxText01DecoderLayer(nn.Module): ...@@ -290,6 +288,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
head_dim = getattr(config, "head_dim", None) head_dim = getattr(config, "head_dim", None)
if head_dim is None: if head_dim is None:
head_dim = config.hidden_size // config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads
rotary_dim = getattr(config, "rotary_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int): if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int):
max_position_embeddings = min( max_position_embeddings = min(
config.max_position_embeddings, config.max_model_len config.max_position_embeddings, config.max_model_len
...@@ -321,9 +321,6 @@ class MiniMaxText01DecoderLayer(nn.Module): ...@@ -321,9 +321,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
head_dim=head_dim, head_dim=head_dim,
rotary_dim=config.rotary_dim
if hasattr(config, "rotary_dim")
else head_dim,
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
......
...@@ -18,15 +18,10 @@ from vllm.model_executor.models.deepseek_v2 import ( ...@@ -18,15 +18,10 @@ from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2DecoderLayer, DeepseekV2DecoderLayer,
DeepseekV2Model, DeepseekV2Model,
) )
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.mistral_large_3 import MistralLarge3ForCausalLM from vllm.model_executor.models.mistral_large_3 import MistralLarge3ForCausalLM
from vllm.multimodal.inputs import NestedTensors
from .utils import ( from .interfaces import SupportsMultiModal
_merge_multimodal_embeddings, from .utils import make_empty_intermediate_tensors_factory, maybe_prefix
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -117,26 +112,10 @@ class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM): ...@@ -117,26 +112,10 @@ class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM):
) )
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
def get_input_embeddings( def get_language_model(self) -> torch.nn.Module:
self, return self.model
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = super().embed_input_ids(input_ids)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
assert is_multimodal is not None embed_input_ids = SupportsMultiModal.embed_input_ids # type: ignore
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
def forward( def forward(
self, self,
...@@ -155,11 +134,3 @@ class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM): ...@@ -155,11 +134,3 @@ class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM):
"model.embed_tokens.weight", "model.embed_tokens.weight",
"lm_head.weight", "lm_head.weight",
} }
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: NestedTensors | None = None,
is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
...@@ -206,7 +206,6 @@ class MixtralAttention(nn.Module): ...@@ -206,7 +206,6 @@ class MixtralAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
......
...@@ -295,11 +295,11 @@ class Llama4VisionAttention(nn.Module): ...@@ -295,11 +295,11 @@ class Llama4VisionAttention(nn.Module):
rope_parameters = { rope_parameters = {
"rope_type": "mllama4", "rope_type": "mllama4",
"rope_theta": config.rope_parameters["rope_theta"], "rope_theta": config.rope_parameters["rope_theta"],
"partial_rotary_factor": 0.5,
} }
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
head_size=self.head_dim, head_size=self.head_dim,
rotary_dim=config.hidden_size // config.num_attention_heads // 2,
# number of image patches # number of image patches
max_position=(config.image_size // config.patch_size) ** 2, max_position=(config.image_size // config.patch_size) ** 2,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
......
...@@ -105,7 +105,6 @@ class ModernBertAttention(nn.Module): ...@@ -105,7 +105,6 @@ class ModernBertAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
head_size=self.head_dim, head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
dtype=torch.float16, dtype=torch.float16,
......
...@@ -433,7 +433,6 @@ class MolmoAttention(nn.Module): ...@@ -433,7 +433,6 @@ class MolmoAttention(nn.Module):
# Rotary embeddings. # Rotary embeddings.
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
......
...@@ -199,7 +199,6 @@ class NemotronAttention(nn.Module): ...@@ -199,7 +199,6 @@ class NemotronAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
......
...@@ -118,7 +118,6 @@ class DeciLMAttention(LlamaAttention): ...@@ -118,7 +118,6 @@ class DeciLMAttention(LlamaAttention):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
......
...@@ -102,7 +102,6 @@ class OlmoAttention(nn.Module): ...@@ -102,7 +102,6 @@ class OlmoAttention(nn.Module):
# Rotary embeddings. # Rotary embeddings.
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
......
...@@ -146,7 +146,6 @@ class Olmo2Attention(nn.Module): ...@@ -146,7 +146,6 @@ class Olmo2Attention(nn.Module):
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta} rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )
......
...@@ -171,7 +171,6 @@ class OlmoeAttention(nn.Module): ...@@ -171,7 +171,6 @@ class OlmoeAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
......
...@@ -240,18 +240,12 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): ...@@ -240,18 +240,12 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
) )
if multimodal_config.get_limit_per_prompt("image"): if multimodal_config.get_limit_per_prompt("image"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = OpenCUAVisionTransformer( self.visual = OpenCUAVisionTransformer(
vision_config=config.vision_config, vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self.quant_config, quant_config=self.quant_config,
multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
) )
else: else:
self.visual = None self.visual = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment