# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Mistral adaptation of the LLaMA architecture."""

from collections.abc import Iterable

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
)
from vllm.v1.attention.backend import AttentionType

from .utils import AutoWeightsLoader


class MistralAttention(LlamaAttention):
    """`LlamaAttention` with optional position-dependent query scaling.

    The scaling is enabled when the config object exposes a
    `llama_4_scaling` attribute that is not `None`.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Build the attention layer and read the optional scaling config.

        All arguments are forwarded unchanged to `LlamaAttention.__init__`.
        Afterwards `config.llama_4_scaling` (if present) is read; it must
        contain the keys `"original_max_position_embeddings"` and `"beta"`.

        Raises:
            KeyError: if `llama_4_scaling` is set but lacks one of the two
                required keys.
        """
        super().__init__(
            config=config,
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=prefix,
            attn_type=attn_type,
        )

        llama_4_scaling_config: dict[str, int | float | str] | None = getattr(
            config, "llama_4_scaling", None
        )
        self.do_llama_4_scaling = llama_4_scaling_config is not None
        if self.do_llama_4_scaling:
            # Narrowing for type checkers only; guaranteed by the flag above.
            assert llama_4_scaling_config is not None
            self.llama_4_scaling_original_max_position_embeddings = (
                llama_4_scaling_config["original_max_position_embeddings"]
            )
            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]

    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the per-position query scale with a trailing size-1 axis.

        Computes
        `1 + beta * log(1 + floor(positions / original_max_position_embeddings))`
        element-wise. Only valid when `self.do_llama_4_scaling` is true,
        because the two attributes it reads are only set in that case.
        """
        # Llama4 scaling
        scaling = 1 + self.llama_4_scaling_beta * torch.log(
            1
            + torch.floor(
                positions / self.llama_4_scaling_original_max_position_embeddings
            )
        )
        # Broadcast over head_dim
        return scaling.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run QKV projection, rotary embedding, attention and output projection.

        When scaling is enabled, only the query is multiplied by the
        position-dependent factor (after the rotary embedding), and the
        result is cast back to the query's original dtype.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)

        if self.do_llama_4_scaling:
            attn_scale = self._get_llama_4_attn_scale(positions)
            # The product may be promoted to a wider dtype; restore q's dtype.
            q = (q * attn_scale).to(q.dtype)

        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class MistralDecoderLayer(LlamaDecoderLayer):
    """`LlamaDecoderLayer` that uses `MistralAttention` for self-attention."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        """Build the layer and optionally wire up quantization-scale fusion.

        `prefix` must end in a dotted component that parses as an integer
        (the layer index); otherwise `int()` raises `ValueError`.
        """
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            config=config,
            attn_layer_type=MistralAttention,
        )

        # The last dotted component of the prefix is taken as the layer index.
        self.layer_idx = int(prefix.split(sep=".")[-1])
        quant_config = self.get_quant_config(vllm_config)

        # Fusion requires both an opt-in flag on the quantization config and a
        # KV-cache dtype whose name starts with "fp8".
        do_fusion = getattr(
            quant_config, "enable_quantization_scaling_fusion", False
        ) and vllm_config.cache_config.cache_dtype.startswith("fp8")
        if do_fusion:
            # Point each norm at the projection that consumes its output.
            self.input_layernorm.quant_scaling_from = self.self_attn.qkv_proj
            self.post_attention_layernorm.quant_scaling_from = self.mlp.gate_up_proj


@support_torch_compile
class MistralModel(LlamaModel):
    """`LlamaModel` whose default layer type is `MistralDecoderLayer`."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = MistralDecoderLayer,
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)


class MistralForCausalLM(LlamaForCausalLM):
    """`LlamaForCausalLM` built on `MistralModel`, with Mistral-format loading."""

    # Mistral: We don't support LoRA on the embedding layers.
    embedding_modules: dict[str, str] = {}

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = MistralDecoderLayer,
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = MistralDecoderLayer,
    ):
        """Construct the `MistralModel` backbone."""
        return MistralModel(
            vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, remapping each `(name, tensor)` pair lazily.

        `lm_head.` weights are skipped when the config ties word embeddings.
        Returns whatever `AutoWeightsLoader.load_weights` returns.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights
        )

    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Permute `wq`/`wk` tensors and rename components via `mistral_mapping`.

        Args:
            name: dotted parameter name from the checkpoint.
            loaded_weight: the tensor stored under that name.

        Returns:
            The (possibly renamed) name and the (possibly permuted) tensor.
        """

        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
            # View as (n_heads, head_dim // 2, 2, attn_out), swap the two
            # middle axes, and flatten back to (attn_in, attn_out).
            # NOTE(review): presumably converts an interleaved rotary layout
            # to a half-split one; confirm against the rotary implementation.
            attn_in = self.config.head_dim * n_heads

            return (
                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                .transpose(1, 2)
                .reshape(attn_in, attn_out)
            )

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        # If using quantized model in mistral format,
        # quantization scales (qscale_weight) also need to be sliced
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
            )
        elif (
            "wk" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            # Single-element scales are left untouched.
            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_attention_heads, self.config.hidden_size
            )
        elif (
            "wq" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)

        last_index = len(modules) - 1
        for i, item in enumerate(modules):
            next_item = modules[i + 1] if i < last_index else None

            # Two-component keys (e.g. "kv_fake_quantizer.qscale_act") take
            # precedence over single-component keys.
            combined_item = f"{item}.{next_item}" if next_item is not None else None

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                # Skip when the target text is already present, so names that
                # already contain the mapped form are not rewritten again.
                name = name.replace(item, mapping[item])

        return name, loaded_weight