# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""ColModernVBERT: multimodal late-interaction retrieval model.

Combines SigLIP vision encoder + ModernBERT text encoder with a pixel
shuffle connector and ColBERT-style 128-dim per-token embeddings.

Reference: https://huggingface.co/ModernVBERT/colmodernvbert-merged
"""

import math
from collections.abc import Iterable, Mapping, Sequence

import torch
from torch import nn
from transformers import BatchFeature

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptIndexTargets,
    PromptReplacement,
    PromptUpdate,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.colmodernvbert import ColModernVBertConfig

from .interfaces import (
    MultiModalEmbeddings,
    SupportsLateInteraction,
    SupportsMultiModal,
)
from .interfaces_base import default_pooling_type
from .modernbert import ModernBertEmbeddings, ModernBertLayer
from .siglip import SiglipVisionModel
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix

# ---------------------------------------------------------------------------
# Connector: pixel shuffle + simple linear projection
# ---------------------------------------------------------------------------


class ColModernVBertConnector(nn.Module):
    """Pixel shuffle spatial reduction followed by a linear projection.

    Reduces the vision encoder's token count by ``factor^2`` via
    pixel-shuffle spatial rearrangement, then projects the concatenated
    channels to the text encoder's hidden size with a single bias-free
    linear layer.
    """

    def __init__(self, config: ColModernVBertConfig):
        super().__init__()
        self.pixel_shuffle_factor = config.pixel_shuffle_factor
        vision_hidden_size = config.vision_config.hidden_size
        # Each output token concatenates a factor x factor patch of inputs.
        input_size = vision_hidden_size * (self.pixel_shuffle_factor**2)
        output_size = config.hidden_size
        self.proj = nn.Linear(input_size, output_size, bias=False)

    def pixel_shuffle(self, features: torch.Tensor) -> torch.Tensor:
        """Spatial rearrangement that reduces seq length by factor^2.

        Args:
            features: Tensor of shape ``(B, S, C)`` where ``S`` is the
                number of tokens of a square ``sqrt(S) x sqrt(S)`` grid.

        Returns:
            Tensor of shape ``(B, H/f, W/f, C * f^2)`` with
            ``H = W = sqrt(S)`` and ``f = self.pixel_shuffle_factor``.

        Raises:
            ValueError: If ``S`` is not a perfect square, or the grid side
                is not divisible by the pixel shuffle factor.
        """
        batch_size, seq_length, hidden_size = features.shape
        factor = self.pixel_shuffle_factor

        # Integer square root avoids any floating-point rounding and lets
        # us reject non-square token grids with a readable error instead
        # of an opaque shape mismatch further down.
        side = math.isqrt(seq_length)
        if side * side != seq_length:
            raise ValueError(
                f"pixel_shuffle expects a square token grid, but got "
                f"sequence length {seq_length}"
            )
        if side % factor != 0:
            raise ValueError(
                f"Token grid side {side} is not divisible by "
                f"pixel_shuffle_factor {factor}"
            )
        reduced = side // factor

        # (B, S, C) -> (B, H/f, f, W/f, f, C). ``reshape`` (rather than
        # ``view``) also accepts non-contiguous inputs.
        features = features.reshape(
            batch_size, reduced, factor, reduced, factor, hidden_size
        )
        # Permute to (B, H/f, W/f, f, f, C)
        features = features.permute(0, 1, 3, 2, 4, 5)
        # Merge each f x f patch into the channel dim: (B, H/f, W/f, C * f^2)
        return features.reshape(
            batch_size, reduced, reduced, hidden_size * (factor**2)
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Pixel-shuffle ``features`` then project them with ``self.proj``.

        The shuffled grid is flattened back to a token sequence of shape
        ``(B, S / f^2, C * f^2)`` before the linear projection.
        """
        features = self.pixel_shuffle(features)
        batch_size = features.shape[0]
        features = features.reshape(batch_size, -1, features.shape[-1])
        return self.proj(features)


# ---------------------------------------------------------------------------
# Multimodal processing
# ---------------------------------------------------------------------------


class ColModernVBertProcessingInfo(BaseProcessingInfo):
    """Static processing information derived from ``ColModernVBertConfig``."""

    def get_hf_config(self) -> ColModernVBertConfig:
        return self.ctx.get_hf_config(ColModernVBertConfig)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square image of the vision config's ``image_size``."""
        config = self.get_hf_config()
        size = config.vision_config.image_size
        return ImageSize(width=size, height=size)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return ``image_seq_len`` from the config.

        The value does not depend on ``image_width`` / ``image_height``.
        """
        return self.get_hf_config().image_seq_len


class ColModernVBertDummyInputsBuilder(
    BaseDummyInputsBuilder[ColModernVBertProcessingInfo],
):
    """Builds dummy inputs: empty text plus max-size dummy images."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        target_width, target_height = self.info.get_image_size_with_most_features()
        image_overrides = mm_options.get("image")
        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


class ColModernVBertMultiModalProcessor(
    BaseMultiModalProcessor[ColModernVBertProcessingInfo],
):
    """Tokenizes text and preprocesses images for ColModernVBERT."""

    def _get_image_processor(self):
        """Return the ``Idefics3ImageProcessor``, loading it on first use.

        The processor depends only on the model name and revision, which
        do not change between calls, so it is cached on the instance
        instead of being re-loaded for every request.
        """
        image_processor = getattr(self, "_cached_idefics3_image_processor", None)
        if image_processor is None:
            from transformers import Idefics3ImageProcessor

            model_config = self.info.ctx.model_config
            image_processor = Idefics3ImageProcessor.from_pretrained(
                model_config.model,
                revision=model_config.revision,
            )
            self._cached_idefics3_image_processor = image_processor
        return image_processor

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Tokenize ``prompt`` and, if images are present, preprocess them.

        Text goes through the tokenizer; images (``mm_data["images"]``) go
        through ``Idefics3ImageProcessor`` with image splitting disabled.
        Both outputs are merged into a single ``BatchFeature``.
        """
        tokenizer = self.info.get_tokenizer()
        text_encoding = tokenizer(
            prompt,
            return_tensors="pt",
            **tok_kwargs,
        )
        result = BatchFeature(data=dict(text_encoding))

        images = mm_data.get("images")
        if images:
            image_outputs = self._get_image_processor()(
                images=images,
                do_image_splitting=False,
                return_tensors="pt",
            )
            result.update(image_outputs)

        return result

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        # ``_call_hf_processor`` only tokenizes the raw prompt; it never
        # inserts image placeholder tokens itself.
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Place ``image_seq_len`` image tokens at the start of the prompt."""
        config = self.info.get_hf_config()
        image_token_id = config.image_token_id
        num_tokens = config.image_seq_len

        def get_replacement(item_idx: int):
            # Every image uses the same fixed number of placeholder tokens.
            return [image_token_id] * num_tokens

        return [
            PromptReplacement(
                modality="image",
                target=PromptIndexTargets.start(),
                replacement=get_replacement,
            ),
        ]


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------


@MULTIMODAL_REGISTRY.register_processor(
    ColModernVBertMultiModalProcessor,
    info=ColModernVBertProcessingInfo,
    dummy_inputs=ColModernVBertDummyInputsBuilder,
)
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColModernVBertForRetrieval(
    nn.Module, SupportsMultiModal, SupportsLateInteraction
):
    """ColModernVBERT multimodal late-interaction retrieval model.

    Architecture:
        Image -> SiglipVisionModel -> ColModernVBertConnector
                                              ↓
        Text  -> ModernBertEmbeddings → [merge] → ModernBertLayers → norm
                                                                      ↓
                                                   custom_text_proj → L2 norm
                                                                      ↓
                                                   per-token 128-d embeddings
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: ColModernVBertConfig = vllm_config.model_config.hf_config
        self.config = config
        text_config = config.text_config
        quant_config = vllm_config.quant_config

        # --- Vision encoder (reuses SiglipVisionModel from siglip.py) ---
        self.vision_model = SiglipVisionModel(
            config.vision_config,
            quant_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        # --- Connector (pixel shuffle + linear projection) ---
        self.connector = ColModernVBertConnector(config)

        # --- Text encoder (built from ModernBERT components directly) ---
        # We build the components individually rather than wrapping
        # ``ModernBertModel`` because ``ModernBertEncoderLayer`` reads
        # ``vllm_config.model_config.hf_config`` which would be
        # ``ColModernVBertConfig``, not ``ModernBertConfig``.
        self.text_embeddings = ModernBertEmbeddings(text_config)
        self.text_layers = nn.ModuleList(
            [
                ModernBertLayer(
                    config=text_config,
                    layer_id=i,
                    # ``maybe_prefix`` (as used for the vision tower above)
                    # avoids a leading "." when ``prefix`` is empty.
                    prefix=maybe_prefix(prefix, f"text_layers.{i}"),
                )
                for i in range(text_config.num_hidden_layers)
            ]
        )
        self.text_final_norm = nn.LayerNorm(
            text_config.hidden_size,
            eps=text_config.norm_eps,
            bias=text_config.norm_bias,
        )

        # --- ColBERT projection (768 -> 128, with bias) ---
        self.custom_text_proj = nn.Linear(
            text_config.hidden_size,
            config.embedding_dim,
            bias=True,
            dtype=vllm_config.model_config.head_dtype,
        )

        # --- Pooler (applies projection + L2 normalize) ---
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler = pooler_for_token_embed(
            pooler_config,
            projector=self.custom_text_proj,
        )

    # ---- multimodal ---------------------------------------------------------

    def _get_image_features(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` with the vision tower and the connector."""
        # Idefics3ImageProcessor may return (batch, tiles, C, H, W);
        # flatten to (batch*tiles, C, H, W) for SiglipVisionModel.
        if pixel_values.dim() == 5:
            b, t, c, h, w = pixel_values.shape
            pixel_values = pixel_values.reshape(b * t, c, h, w)

        vision_outputs = self.vision_model(
            pixel_values.to(dtype=self.vision_model.dtype),
        )
        return self.connector(vision_outputs)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return one embedding tensor per image, or ``[]`` without images.

        Reads ``pixel_values`` from ``kwargs``; the resulting feature
        tensor is split along its first dimension into a list.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return []
        assert isinstance(pixel_values, torch.Tensor)
        image_features = self._get_image_features(pixel_values)
        return list(image_features)

    # ---- forward ------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the text encoder and return the final-norm hidden states.

        ``inputs_embeds`` is forwarded to the embedding module together
        with ``input_ids``; ``positions`` is passed to every text layer.
        ``intermediate_tensors`` is accepted but not used here.
        """
        hidden_states = self.text_embeddings(input_ids, inputs_embeds=inputs_embeds)

        for layer in self.text_layers:
            hidden_states = layer(hidden_states, positions)

        return self.text_final_norm(hidden_states)

    # ---- weight loading -----------------------------------------------------

    # Checkpoint prefix → vLLM param prefix.
    # More-specific prefixes must appear before shorter ones.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.text_model.layers.": "text_layers.",
            "model.text_model.embeddings.": "text_embeddings.",
            "model.text_model.final_norm.": "text_final_norm.",
            "model.connector.modality_projection.": "connector.",
            "model.custom_text_proj.": "custom_text_proj.",
            "model.vision_model.": "vision_model.vision_model.",
            "model.": "",
        },
    )

    # Checkpoint names for DecoupledEmbedding parts
    _BASE_EMB = "model.text_model.embeddings.tok_embeddings.weight"
    _EXTRA_EMB = (
        "model.text_model.embeddings.tok_embeddings.additional_embedding.weight"
    )

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load checkpoint weights and return the names of loaded params.

        The base and additional token-embedding tensors are pulled out of
        the stream, concatenated along dim 0 and loaded as a single
        ``text_embeddings.tok_embeddings.weight``; everything else goes
        through ``AutoWeightsLoader`` with ``hf_to_vllm_mapper``.

        Raises:
            ValueError: If the additional embedding is present in the
                checkpoint but the base embedding is not.
        """
        # DecoupledEmbedding requires concatenating base + additional
        # embedding tensors before loading, so we extract them first.
        base_embedding_weight: torch.Tensor | None = None
        additional_embedding_weight: torch.Tensor | None = None
        remaining: list[tuple[str, torch.Tensor]] = []

        for name, tensor in weights:
            if name == self._BASE_EMB:
                base_embedding_weight = tensor
            elif name == self._EXTRA_EMB:
                additional_embedding_weight = tensor
            else:
                remaining.append((name, tensor))

        # Load all non-embedding weights via AutoWeightsLoader
        loader = AutoWeightsLoader(self)
        loaded_params = loader.load_weights(
            remaining,
            mapper=self.hf_to_vllm_mapper,
        )

        # Concatenate and load DecoupledEmbedding weights
        if base_embedding_weight is not None:
            combined = base_embedding_weight
            if additional_embedding_weight is not None:
                combined = torch.cat(
                    [base_embedding_weight, additional_embedding_weight],
                    dim=0,
                )
            param_name = "text_embeddings.tok_embeddings.weight"
            params_dict = dict(self.named_parameters())
            if param_name in params_dict:
                param = params_dict[param_name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                weight_loader(param, combined)
                loaded_params.add(param_name)
        elif additional_embedding_weight is not None:
            raise ValueError(
                "Found 'text_model.embeddings.tok_embeddings"
                ".additional_embedding.weight' but not "
                "'text_model.embeddings.tok_embeddings.weight'"
            )

        # The pooler wraps ``custom_text_proj`` as its head projector.
        # Mark those params as loaded under the pooler path too.
        if hasattr(self, "pooler") and hasattr(self.pooler, "head"):
            head = self.pooler.head
            projector = getattr(head, "projector", None)
            if projector is not None and isinstance(projector, nn.Module):
                for pname, _ in projector.named_parameters():
                    loaded_params.add(f"pooler.head.projector.{pname}")

        return loaded_params