# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields from typing import Annotated, Literal import torch import torch.nn as nn import torch.nn.functional as F from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from transformers import PixtralVisionConfig from transformers.models.pixtral.image_processing_pixtral import ( _num_image_tokens as _get_pixtral_hf_num_image_tokens, ) from transformers.models.pixtral.modeling_pixtral import ( PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid, ) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import ( MultiModalDataDict, MultiModalFieldConfig, NestedTensors, ) from vllm.multimodal.parse import ( ImageProcessorItems, ImageSize, MultiModalDataItems, ) from vllm.multimodal.processing import BaseDummyInputsBuilder from vllm.multimodal.processing.processor import ( BaseMultiModalProcessor, BaseProcessingInfo, MultiModalProcessingInfo, ProcessorInputs, PromptReplacement, PromptUpdate, PromptUpdateDetails, TimingContext, ) from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers.mistral import MistralTokenizer from vllm.transformers_utils.processors.pixtral import MistralCommonPixtralProcessor from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import ( MultiModalEmbeddings, SupportsEagle3, SupportsLoRA, SupportsMultiModal, SupportsPP, supports_eagle3, ) from .module_mapping import MultiModelKeys from .utils import StageMissingLayer, init_vllm_registered_model, maybe_prefix from .vision import ( VisionEncoderInfo, VisionFeatureSelectStrategy, is_vit_use_data_parallel, resolve_visual_encoder_outputs, ) try: # Note: vLLM does not install xformers by default. from xformers import ops as xops if current_platform.is_cuda() and current_platform.has_device_capability(100): # Xformers FA is not compatible with B200 USE_XFORMERS_OPS = False else: USE_XFORMERS_OPS = True except ImportError: USE_XFORMERS_OPS = False PATCH_MERGE = "patch_merge" def _is_layer_none_or_staged(layer: nn.Module) -> bool: return layer is None or isinstance(layer, StageMissingLayer) class PixtralImagePixelInputs(TensorSchema): """ Dimensions: - bn: Batch size * number of images - c: Number of channels (3) - h: Height of each image - w: Width of each image The result of stacking `ImageEncoding.tokens` from each prompt. """ type: Literal["pixel_values"] = "pixel_values" images: Annotated[ torch.Tensor | list[torch.Tensor], TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}), ] class PixtralProcessingInfo(BaseProcessingInfo): def get_tokenizer(self) -> MistralTokenizer: tokenizer = cached_tokenizer_from_config(self.ctx.model_config) if not isinstance(tokenizer, MistralTokenizer): raise ValueError("This model requires `--tokenizer-mode mistral`") return tokenizer def get_hf_processor(self, **kwargs) -> MistralCommonPixtralProcessor: return self.ctx.init_processor( MistralCommonPixtralProcessor, tokenizer=self.get_tokenizer(), **kwargs, ) def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_image_size_with_most_features(self) -> ImageSize: image_processor = self.get_hf_processor().image_processor max_image_size = image_processor.mm_encoder.mm_config.max_image_size return ImageSize(width=max_image_size, height=max_image_size) class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], mm_options: Mapping[str, BaseDummyOptions], ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) target_width, target_height = self.info.get_image_size_with_most_features() image_overrides = mm_options.get("image") return { "image": self._get_dummy_images( width=target_width, height=target_height, num_images=num_images, overrides=image_overrides, ) } def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], mm_options: Mapping[str, BaseDummyOptions], mm_data: MultiModalDataDict | None = None, ) -> ProcessorInputs: tokenizer = self.info.get_tokenizer() dummy_text = self.get_dummy_text(mm_counts) dummy_mm_data = ( self.get_dummy_mm_data(seq_len, mm_counts, mm_options) if mm_data is None else mm_data ) dummy_mm_items = self.info.parse_mm_data(dummy_mm_data) dummy_images = ( [] if "image" not in dummy_mm_data else dummy_mm_items["image"].get_all() ) request = ChatCompletionRequest( messages=[ UserMessage( content=[ TextChunk(text=dummy_text), *(ImageChunk(image=image) for image in dummy_images), ] ), ] ) res = tokenizer.mistral.encode_chat_completion(request) dummy_tokens = res.tokens return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items) class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]): def _get_mm_fields_config( self, hf_inputs: Mapping[str, NestedTensors], hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return dict(images=MultiModalFieldConfig.batched("image")) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_break_id = processor.image_break_id image_token_id = processor.image_token_id image_end_id = processor.image_end_id def get_replacement(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) image_size = images.get_image_size(item_idx) _, nrows, ncols = processor.image_processor.get_number_of_image_patches( image_size.height, image_size.width, ) tokens = ([image_token_id] * ncols + [image_break_id]) * nrows tokens[-1] = image_end_id return PromptUpdateDetails.select_token_id(tokens, image_token_id) return [ PromptReplacement( modality="image", target="", # Never match the prompt (see below note) replacement=get_replacement, ), ] def _cached_apply_hf_processor( self, inputs: ProcessorInputs, timing_ctx: TimingContext, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(inputs, timing_ctx) # NOTE: The tokens are already inserted by the chat template return prompt_ids, mm_info, True @MULTIMODAL_REGISTRY.register_processor( PixtralMultiModalProcessor, info=PixtralProcessingInfo, dummy_inputs=PixtralDummyInputsBuilder, ) class PixtralForConditionalGeneration( nn.Module, SupportsLoRA, SupportsEagle3, SupportsMultiModal, SupportsPP ): @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return None raise ValueError("Only image modality is supported") def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config dataclass_fields = {field.name for field in fields(VisionEncoderArgs)} vision_args = { key: value for key, value in self.config.vision_config.to_dict().items() if key in dataclass_fields } self.vision_args = VisionEncoderArgs(**vision_args) # init MistralForCausalLM with self._mark_language_model(vllm_config): self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), ) with self._mark_tower_model(vllm_config, "image"): self.vision_encoder = VisionTransformer(self.vision_args) self.pre_mm_projector_norm = ( RMSNorm(self.vision_args.hidden_size, eps=1e-5) if self.vision_args.add_pre_mm_projector_layer_norm else None ) self.patch_merger = ( PatchMerger( vision_encoder_dim=self.vision_args.hidden_size, spatial_merge_size=self.vision_args.spatial_merge_size, use_mlp_bias=False, ) if self.vision_args.mm_projector_id == PATCH_MERGE else None ) self.vision_language_adapter = VisionLanguageAdapter( self.vision_args, dim=config.text_config.hidden_size ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors ) def _parse_and_validate_image_input( self, **kwargs: object ) -> PixtralImagePixelInputs | None: images = kwargs.pop("images", None) if images is None: return None return PixtralImagePixelInputs( type="pixel_values", images=images, ) def _process_image_input( self, image_input: PixtralImagePixelInputs, ) -> tuple[torch.Tensor, ...]: images = image_input["images"] image_features = self.vision_encoder(images) feature_sizes = [image_feature.shape[0] for image_feature in image_features] image_features = torch.cat(image_features) if self.pre_mm_projector_norm is not None: image_features = self.pre_mm_projector_norm(image_features) if self.patch_merger is not None: patch_size = self.vision_args.patch_size spatial_merge_size_square = self.vision_args.spatial_merge_size**2 img_patch_dims = [ (img.shape[1] // patch_size, img.shape[2] // patch_size) for img in images ] feature_sizes = [ feature_size // spatial_merge_size_square for feature_size in feature_sizes ] image_features = self.patch_merger( image_features, image_sizes=img_patch_dims ) image_embeds = self.vision_language_adapter(image_features) image_embeds = torch.split(image_embeds, feature_sizes) return image_embeds def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] return self._process_image_input(image_input) def forward( self, input_ids: torch.Tensor | None, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> torch.Tensor | IntermediateTensors: """Run forward pass for pixtral.""" if intermediate_tensors is not None: inputs_embeds = None hidden_states = self.language_model.model( input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def _require_language_model_eagle3(self) -> None: if not supports_eagle3(self.language_model): raise RuntimeError( f"EAGLE-3 speculative decoding requires the language model to " f"support EAGLE-3, but {type(self.language_model).__name__} does not." ) def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self._require_language_model_eagle3() self.language_model.set_aux_hidden_state_layers(layers) def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: self._require_language_model_eagle3() return self.language_model.get_eagle3_aux_hidden_state_layers() def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]): return weight[0].startswith(("vision_encoder", "vision_tower")) def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]): return weight[0].startswith( ("vision_language_adapter", "multi_modal_projector") ) def is_patch_merger(weight: tuple[str, torch.Tensor]): return weight[0].startswith("patch_merger") def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]): return weight[0].startswith("pre_mm_projector_norm") # Get references to parameters for direct loading vision_encoder_dict = ( dict(self.vision_encoder.named_parameters()) if self.vision_encoder is not None else {} ) patch_merger_dict = ( dict(self.patch_merger.named_parameters()) if self.patch_merger is not None else {} ) pre_mm_projector_norm_dict = ( dict(self.pre_mm_projector_norm.named_parameters()) if self.pre_mm_projector_norm is not None else {} ) vision_lang_adapter_dict = ( dict(self.vision_language_adapter.named_parameters()) if self.vision_language_adapter is not None else {} ) def llm_weights_generator(): # Single pass over weights for name, w in weights: if is_vision_encoder_weights((name, w)): if _is_layer_none_or_staged(self.vision_encoder): continue # Load vision encoder weights directly trimmed_name = ".".join(name.split(".")[1:]) param = vision_encoder_dict.get(trimmed_name) if param is not None: with torch.no_grad(): default_weight_loader(param, w) elif is_patch_merger((name, w)): if _is_layer_none_or_staged(self.patch_merger): continue # Load vision patch merger weights directly trimmed_name = ".".join(name.split(".")[1:]) param = patch_merger_dict[trimmed_name] with torch.no_grad(): default_weight_loader(param, w) elif is_pre_mm_projector_norm((name, w)): if _is_layer_none_or_staged(self.pre_mm_projector_norm): continue # Load vision pre_mm_projector_norm weights directly trimmed_name = ".".join(name.split(".")[1:]) param = pre_mm_projector_norm_dict[trimmed_name] with torch.no_grad(): default_weight_loader(param, w) elif is_vision_lang_adapter_weights((name, w)): if _is_layer_none_or_staged(self.vision_language_adapter): continue # Load vision-language adapter weights directly trimmed_name = ".".join(name.split(".")[1:]) param = vision_lang_adapter_dict.get(trimmed_name) if param is not None: with torch.no_grad(): default_weight_loader(param, w) else: # LLM weights: yield them to be loaded # by language_model.load_weights # Strip "language_model." prefix if present (HF sharded format) name = name.removeprefix("language_model.") yield (name, w) # Now we call the language model load with the generator self.language_model.load_weights(llm_weights_generator()) def get_mm_mapping(self) -> MultiModelKeys: return MultiModelKeys.from_string_field( language_model="language_model", connector="vision_language_adapter", tower_model="vision_encoder", ) def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int: if getattr(self, "patch_merger", None) is None: return num_image_tokens merge_size = self.vision_args.spatial_merge_size return num_image_tokens * (merge_size**2) def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int: if getattr(self, "patch_merger", None) is None: return num_vision_tokens merge_size = self.vision_args.spatial_merge_size return num_vision_tokens // (merge_size**2) # Vision encoder @dataclass class VisionEncoderArgs: hidden_size: int num_channels: int image_size: int patch_size: int intermediate_size: int num_hidden_layers: int num_attention_heads: int rope_theta: float # for rope-2D image_token_id: int adapter_bias: bool = True spatial_merge_size: int = 1 add_pre_mm_projector_layer_norm: bool = False mm_projector_id: str = "" def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: """ freqs_cis: complex - (seq_len, head_dim / 2) x: complex - (bsz, seq_len, head_dim / 2) """ ndim = x.ndim assert ndim > 1 assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( freqs_cis.shape, (x.shape[1], x.shape[-1]), ) shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(*shape) def precompute_freqs_cis_2d( dim: int, height: int, width: int, theta: float, ) -> torch.Tensor: """ freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by (height, width) position tuples """ # (dim / 2) frequency bases freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) h = torch.arange(height, device=freqs.device) w = torch.arange(width, device=freqs.device) freqs_h = torch.outer(h, freqs[::2]).float() freqs_w = torch.outer(w, freqs[1::2]).float() freqs_2d = torch.cat( [ freqs_h[:, None, :].repeat(1, width, 1), freqs_w[None, :, :].repeat(height, 1, 1), ], dim=-1, ) return torch.polar(torch.ones_like(freqs_2d), freqs_2d) def apply_rotary_emb_vit( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) assert freqs_cis.dtype == torch.complex64 freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) class FeedForward(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() assert args.intermediate_size is not None self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False) self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) class Attention(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() self.args = args assert not args.hidden_size % args.num_attention_heads self.n_heads = args.num_attention_heads self.head_dim = args.hidden_size // args.num_attention_heads self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False) self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False) self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False) self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False) def forward( self, x: torch.Tensor, mask: torch.Tensor, freqs_cis: torch.Tensor, ) -> torch.Tensor: batch, patches, _ = x.shape q, k, v = self.wq(x), self.wk(x), self.wv(x) q = q.reshape(batch, patches, self.n_heads, self.head_dim) k = k.reshape(batch, patches, self.n_heads, self.head_dim) v = v.reshape(batch, patches, self.n_heads, self.head_dim) q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis) if USE_XFORMERS_OPS: out = xops.memory_efficient_attention(q, k, v, attn_bias=mask) else: q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) out = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) out = out.transpose(1, 2) out = out.reshape(batch, patches, self.n_heads * self.head_dim) return self.wo(out) class TransformerBlock(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() self.attention = Attention(args) self.feed_forward = FeedForward(args) self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5) self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5) def forward( self, x: torch.Tensor, mask: torch.Tensor, freqs_cis: torch.Tensor, ) -> torch.Tensor: r = self.attention.forward( self.attention_norm(x), mask=mask, freqs_cis=freqs_cis ) h = x + r r = self.feed_forward.forward(self.ffn_norm(h)) out = h + r return out class Transformer(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() self.layers = torch.nn.ModuleList() for _ in range(args.num_hidden_layers): self.layers.append(TransformerBlock(args)) def forward( self, x: torch.Tensor, mask: torch.Tensor, freqs_cis: torch.Tensor | None, ) -> torch.Tensor: for layer in self.layers: x = layer(x, mask=mask, freqs_cis=freqs_cis) return x def position_meshgrid( patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor: positions = torch.cat( [ torch.stack( torch.meshgrid( torch.arange(p.shape[-2]), torch.arange(p.shape[-1]), indexing="ij", ), dim=-1, ).reshape(-1, 2) for p in patch_embeds_list ] ) return positions class VisionTransformer(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() self.args = args self.patch_conv = Conv2dLayer( in_channels=args.num_channels, out_channels=args.hidden_size, kernel_size=args.patch_size, stride=args.patch_size, bias=False, ) self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5) self.transformer = Transformer(args) head_dim = self.args.hidden_size // self.args.num_attention_heads assert head_dim % 2 == 0, "ROPE requires even head_dim" self._freqs_cis: torch.Tensor | None = None @property def max_patches_per_side(self) -> int: return self.args.image_size // self.args.patch_size @property def device(self) -> torch.types.Device: return next(self.parameters()).device @property def dtype(self) -> torch.dtype: return next(self.parameters()).dtype @property def freqs_cis(self) -> torch.Tensor: if self._freqs_cis is None: self._freqs_cis = precompute_freqs_cis_2d( dim=self.args.hidden_size // self.args.num_attention_heads, height=self.max_patches_per_side, width=self.max_patches_per_side, theta=self.args.rope_theta, ) if self._freqs_cis.device != self.device: self._freqs_cis = self._freqs_cis.to(device=self.device) return self._freqs_cis def forward( self, images: list[torch.Tensor], ) -> torch.Tensor: """ Args: images: list of N_img images of variable sizes, each of shape (C, H, W) Returns: image_features: tensor of token features for all tokens of all images of shape (N_toks, D) """ # pass images through initial convolution independently patch_embeds_list = [ self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images ] patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list] embed_sizes = [p.shape[1] for p in patch_embeds] # flatten to a single sequence patch_embeds = torch.cat(patch_embeds, dim=1) patch_embeds = self.ln_pre(patch_embeds) # positional embeddings positions = position_meshgrid(patch_embeds_list).to(self.device) freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] # pass through Transformer with a block diagonal mask delimiting images if USE_XFORMERS_OPS: mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens( [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) else: from transformers.models.pixtral.modeling_pixtral import ( generate_block_attention_mask, ) mask = generate_block_attention_mask( [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds ) out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) # squeeze dim 0 and split into separate tensors for each image return torch.split(out.squeeze(0), embed_sizes) class VisionLanguageAdapter(nn.Module): def __init__(self, args: VisionEncoderArgs, dim: int): super().__init__() assert isinstance(args, VisionEncoderArgs) self.w_in = nn.Linear( args.hidden_size, dim, bias=args.adapter_bias, ) self.gelu = nn.GELU() self.w_out = nn.Linear(dim, dim, bias=args.adapter_bias) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w_out(self.gelu(self.w_in(x))) class PatchMerger(nn.Module): """ Learned merging of spatial_merge_size ** 2 patches """ def __init__( self, vision_encoder_dim: int, spatial_merge_size: int, use_mlp_bias: bool = False, ) -> None: super().__init__() mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2) self.spatial_merge_size = spatial_merge_size self.mlp_input_dim = mlp_input_dim self.merging_layer = nn.Linear( mlp_input_dim, vision_encoder_dim, bias=use_mlp_bias, ) def forward( self, x: torch.Tensor, image_sizes: list[tuple[int, int]] ) -> torch.Tensor: # image_sizes specified in tokens assert sum([h * w for h, w in image_sizes]) == len(x) # x is (N, vision_encoder_dim) x = self.permute(x, image_sizes) # x is (N / spatial_merge_size ** 2, # vision_encoder_dim * spatial_merge_size ** 2) x = self.merging_layer(x) # x is (N / spatial_merge_size ** 2, vision_encoder_dim) return x def permute( self, x: torch.Tensor, image_sizes: list[tuple[int, int]], ) -> torch.Tensor: """ Args: x: (N, D) where N is flattened and concatenated patch tokens for all images image_sizes: list of tuple of (height, width) in tokens for each image Returns: image_features: reorders patch tokens so each grid of (spatial_merge_size, spatial_merge_size) is contiguous. now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2) """ sub_grids = get_sub_grids( x=x, image_sizes=image_sizes, spatial_merge_size=self.spatial_merge_size ) # list of [d x sub_grid_size x sub_grid_size x n_patches] permuted_tensor: list[torch.Tensor] = [] for grid in sub_grids: n_patches = grid.shape[-1] permuted_tensor.append( grid.view(-1, n_patches).t() ) # n_patches x d * sub_grid_size * sub_grid_size return torch.cat( permuted_tensor, dim=0 ) # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2) def get_sub_grids( x: torch.Tensor, image_sizes: list[tuple[int, int]], spatial_merge_size: int, ) -> list[torch.Tensor]: # image_sizes specified in tokens tokens_per_image = [h * w for h, w in image_sizes] d = x.shape[-1] all_img_sub_grids: list[torch.Tensor] = [] sub_grid_size = spatial_merge_size for image_index, image_tokens in enumerate(x.split(tokens_per_image)): # Reshape image_tokens into a 2D grid h, w = image_sizes[image_index] image_grid = image_tokens.view(h, w, d).permute(2, 0, 1)[ None, :, :, : ] # 1 x d x h x w sub_grids = torch.nn.functional.unfold( image_grid, kernel_size=sub_grid_size, stride=sub_grid_size ) sub_grids = sub_grids.view( 1, d, sub_grid_size, sub_grid_size, -1 ) # 1 x d x sub_grid_size x sub_grid_size x n_patches all_img_sub_grids.append(sub_grids[0]) return all_img_sub_grids #### HF Transformers version of Pixtral #### # Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py # This model follows the Llava family, meaning image embeddings are placed # instead of the `[IMG]` token placeholders. # The model uses [`PixtralVisionModel`] for its vision encoder, # and [`MistralForCausalLM`] for its language decoder. class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): def get_num_image_tokens( self, *, image_width: int, image_height: int, ) -> int: ncols, nrows = self.get_patch_grid_size( image_width=image_width, image_height=image_height, ) return ncols * nrows def get_image_size(self) -> int: return self.vision_config.image_size def get_patch_size(self) -> int: # spatial_merge_size is needed for Mistral3 spatial_merge_size = getattr(self.hf_config, "spatial_merge_size", 1) return self.vision_config.patch_size * spatial_merge_size def get_patch_grid_length(self) -> int: image_size, patch_size = self.get_image_size(), self.get_patch_size() # Since interpolation is applied, the image size need not be divisible # assert image_size % patch_size == 0 return image_size // patch_size # Adapted from: https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/pixtral/image_processing_pixtral.py#L99 def get_patch_grid_size( self, *, image_width: int, image_height: int, ) -> tuple[int, int]: max_width = max_height = self.get_image_size() patch_width = patch_height = self.get_patch_size() ratio = max(image_width / max_width, image_height / max_height) if ratio > 1: image_width = int(math.floor(image_width / ratio)) image_height = int(math.floor(image_height / ratio)) nrows, ncols = _get_pixtral_hf_num_image_tokens( (image_height, image_width), (patch_height, patch_width), ) # type: ignore return ncols, nrows class PixtralHFMLP(nn.Module): def __init__( self, config: PixtralVisionConfig, quant_config: QuantizationConfig | None = None, *, prefix: str = "", ) -> None: super().__init__() use_data_parallel = is_vit_use_data_parallel() assert config.intermediate_size is not None self.gate_up_proj = MergedColumnParallelLinear( input_size=config.hidden_size, output_sizes=[config.intermediate_size] * 2, bias=False, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", disable_tp=use_data_parallel, ) self.down_proj = RowParallelLinear( input_size=config.intermediate_size, output_size=config.hidden_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.down_proj", disable_tp=use_data_parallel, ) self.act_and_mul = get_act_and_mul_fn(config.hidden_act) def forward(self, x: torch.Tensor) -> torch.Tensor: gate_up, _ = self.gate_up_proj(x) x = self.act_and_mul(gate_up) x, _ = self.down_proj(x) return x class PixtralHFAttention(nn.Module): def __init__( self, config: PixtralVisionConfig, quant_config: QuantizationConfig | None = None, *, prefix: str = "", ) -> None: super().__init__() self.config = config assert not config.hidden_size % config.num_attention_heads self.total_num_heads = config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads assert self.total_num_heads * self.head_dim == config.hidden_size use_data_parallel = is_vit_use_data_parallel() self.qkv_proj = QKVParallelLinear( hidden_size=config.hidden_size, head_size=self.head_dim, total_num_heads=self.total_num_heads, bias=False, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", disable_tp=use_data_parallel, ) self.o_proj = RowParallelLinear( input_size=config.hidden_size, output_size=config.hidden_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.o_proj", disable_tp=use_data_parallel, ) self.tp_size = ( 1 if use_data_parallel else get_tensor_model_parallel_world_size() ) self.n_heads = divide(config.num_attention_heads, self.tp_size) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor | None]: batch, patches, _ = hidden_states.size() qkv_states, _ = self.qkv_proj(hidden_states) q, k, v = qkv_states.chunk(3, dim=-1) # Transpose q and k to apply HF's Rotary Position Embedding q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2) k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2) v = v.view(batch, patches, self.n_heads, self.head_dim) cos, sin = position_embeddings q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0) if USE_XFORMERS_OPS: # Transpose q and k back for attention q = q.transpose(1, 2).contiguous() k = k.transpose(1, 2).contiguous() out = xops.memory_efficient_attention(q, k, v, attn_bias=attention_mask) else: v = v.transpose(1, 2) out = nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=attention_mask ) out = out.transpose(1, 2) out = out.reshape(batch, patches, self.n_heads * self.head_dim) attn_output, _ = self.o_proj(out) return attn_output, None class PixtralHFTransformerBlock(nn.Module): def __init__( self, config: PixtralVisionConfig, quant_config: QuantizationConfig | None = None, *, prefix: str = "", ) -> None: super().__init__() self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5) self.attention = PixtralHFAttention( config, quant_config=quant_config, prefix=f"{prefix}.attention", ) self.feed_forward = PixtralHFMLP( config, quant_config=quant_config, prefix=f"{prefix}.feed_forward", ) self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, ) -> torch.Tensor: r, _ = self.attention.forward( self.attention_norm(hidden_states), attention_mask=attention_mask, position_embeddings=position_embeddings, ) h = hidden_states + r r = self.feed_forward.forward(self.ffn_norm(h)) out = h + r return out class PixtralHFTransformer(nn.Module): def __init__( self, config: PixtralVisionConfig, quant_config: QuantizationConfig | None = None, *, num_hidden_layers_override: int | None = None, prefix: str = "", ) -> None: super().__init__() if num_hidden_layers_override is None: num_hidden_layers = config.num_hidden_layers else: num_hidden_layers = num_hidden_layers_override self.layers = nn.ModuleList( [ PixtralHFTransformerBlock( config=config, quant_config=quant_config, prefix=f"{prefix}.layers.{layer_idx}", ) for layer_idx in range(num_hidden_layers) ] ) def forward( self, x: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, return_all_hidden_states: bool, ) -> torch.Tensor: hidden_states_pool = [x] for layer in self.layers: x = layer(x, attention_mask, position_embeddings) if return_all_hidden_states: hidden_states_pool.append(x) # If we have multiple feature sample layers, we return all hidden # states in order and grab the ones we need by index. if return_all_hidden_states: return hidden_states_pool return x class PixtralHFVisionModel(nn.Module): def __init__( self, config: PixtralVisionConfig, quant_config: QuantizationConfig | None = None, *, num_hidden_layers_override: int | None = None, require_post_norm: bool | None = None, prefix: str = "", ) -> None: super().__init__() self.config = config self.patch_conv = Conv2dLayer( in_channels=config.num_channels, out_channels=config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size, bias=False, ) self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5) self.transformer = PixtralHFTransformer( config, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.transformer", ) num_hidden_layers = config.num_hidden_layers if len(self.transformer.layers) > config.num_hidden_layers: raise ValueError( f"The original encoder only has {num_hidden_layers} " f"layers, but you requested {len(self.transformer.layers)} " "layers." ) if require_post_norm is True: msg = "PixtralHFVisionModel does not have post-layernorm" raise ValueError(msg) self.dtype = next(self.parameters()).dtype self.device = next(self.parameters()).device self.patch_positional_embedding = PixtralRotaryEmbedding(config, self.device) def forward( self, pixel_values: list[torch.Tensor], *, select_layers: list[int] | None = None, feature_select_strategy: VisionFeatureSelectStrategy | None = None, ) -> tuple[torch.Tensor, ...]: """ Args: pixel_values: Each image to be processed will be a separate tensor in pixel_values. This means it will be a list of tensors because multiple requests batched can have multiple images, each with their own shape potentially select_layers: Layer indices whose features should be concatenated and used as the visual encoder output. If none are provided, the last layer is used. Returns: image_features: tensor of token features for all tokens of all images of shape (N_toks, D) """ # pass images through initial convolution independently patch_embeds_list = [ self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values ] patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list] embed_sizes = [p.shape[1] for p in patch_embeds] # flatten to a single sequence patch_embeds = torch.cat(patch_embeds, dim=1) patch_embeds = self.ln_pre(patch_embeds) # positional embeddings position_ids = position_ids_in_meshgrid( patch_embeds_list, max_width=self.config.image_size // self.config.patch_size, ).to(self.device) position_embedding = self.patch_positional_embedding(patch_embeds, position_ids) if USE_XFORMERS_OPS: attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens( [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) else: from transformers.models.pixtral.modeling_pixtral import ( generate_block_attention_mask, ) attention_mask = generate_block_attention_mask( [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds ) out = self.transformer( patch_embeds, attention_mask, position_embedding, return_all_hidden_states=select_layers is not None, ) out = resolve_visual_encoder_outputs( out, None, select_layers=select_layers, max_possible_layers=self.config.num_hidden_layers, feature_select_strategy=feature_select_strategy, ) # squeeze dim 0 and split into separate tensors for each image return torch.split(out.squeeze(0), embed_sizes) # (TODO) Add prefix argument for filtering out weights to be loaded # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), (".qkv_proj", ".v_proj", "v"), (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() layer_count = len(self.transformer.layers) for name, loaded_weight in weights: # omit layers when num_hidden_layers_override is set if name.startswith("transformer.layers"): layer_idx = int(name.split(".")[2]) if layer_idx >= layer_count: continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params