# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # -------------------------------------------------------- # InternS1 # Copyright (c) 2025 Shanghai AI Lab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from collections.abc import Iterable, Mapping, Sequence from typing import Literal, Optional, TypedDict, Union import torch import torch.nn as nn from transformers import InternVLProcessor, PretrainedConfig from transformers.activations import ACT2FN from transformers.models.got_ocr2.image_processing_got_ocr2_fast import ( GotOcr2ImageProcessorFast) from vllm.config import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.interns1_vit import InternS1VisionModel from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) class InternS1MultiModalProjector(nn.Module): def __init__(self, config): super().__init__() self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size * int(1 / config.downsample_ratio)**2) self.linear_1 = nn.Linear( config.vision_config.hidden_size * int(1 / config.downsample_ratio)**2, config.text_config.hidden_size) self.act = ACT2FN[config.projector_hidden_act] self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size) def forward(self, image_features): hidden_states = self.layer_norm(image_features) hidden_states = self.linear_1(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states class InternS1ImagePixelInputs(TypedDict): type: Literal["pixel_values"] pixel_values: torch.Tensor """ Shape: `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` """ class InternS1ImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] data: Union[torch.Tensor, list[torch.Tensor]] """ A tensor of shape `(num_images, total_image_feature_size, hidden_size)` or a list of tensors of shape `(total_image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ InternS1ImageInputs = Union[InternS1ImagePixelInputs, InternS1ImageEmbeddingInputs] class InternS1VideoPixelInputs(TypedDict): type: Literal["pixel_values_videos"] pixel_values: torch.Tensor """ Shape: `(batch_size * num_video * num_frames, num_channels, height, width)` """ num_patches: torch.Tensor """Shape: `(batch_size * num_images)`""" class InternS1VideoEmbeddingInputs(TypedDict): type: Literal["video_embeds"] data: Union[torch.Tensor, list[torch.Tensor]] """ A tensor of shape `(num_videos, total_video_feature_size, hidden_size)` or a list of tensors of shape `(total_video_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ InternS1VideoInputs = Union[InternS1VideoPixelInputs, InternS1VideoEmbeddingInputs] def resolve_interns1_min_max_num( min_dynamic_patch: int, max_dynamic_patch: int, dynamic_image_size: bool, use_thumbnail: bool, ) -> tuple[int, int]: min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1 max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 if use_thumbnail and max_dynamic_patch != 1: max_dynamic_patch += 1 return min_dynamic_patch, max_dynamic_patch def get_interns1_target_ratios( min_num: int, max_num: int, ) -> list[tuple[int, int]]: target_ratios = {(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if min_num <= i * j <= max_num} return sorted(target_ratios, key=lambda x: x[0] * x[1]) class InternS1ProcessingInfo(BaseProcessingInfo): """Basic image-only ProcessingInfo for InternS1-style models.""" def get_hf_processor(self, **kwargs: object) -> InternVLProcessor: return self.ctx.get_hf_processor(InternVLProcessor, **kwargs) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} def get_num_image_tokens( self, *, image_width: int, image_height: int, processor: Optional['GotOcr2ImageProcessorFast'] = None, ) -> int: if processor is None: processor = self.get_hf_processor().image_processor if not isinstance(processor, GotOcr2ImageProcessorFast): raise ValueError(f'GotOcr2ImageProcessorFast is expected but got ' f'{type(processor)}') num_image_patches = processor.get_number_of_image_tokens( image_height, image_width, images_kwargs=dict()) num_image_tokens = self.get_hf_processor( ).image_seq_length * num_image_patches return num_image_tokens def resolve_target_ratios(self, use_thumbnail: Optional[bool] = None): image_processor = self.get_hf_processor().image_processor min_dynamic_patch = image_processor.min_patches max_dynamic_patch = image_processor.max_patches # HF format's InternVL processor uses `crop_to_patches` which is # equivalent to `use_thumbnail` in original format. use_thumbnail = image_processor.crop_to_patches dynamic_image_size = True min_num, max_num = resolve_interns1_min_max_num( min_dynamic_patch, max_dynamic_patch, dynamic_image_size, use_thumbnail=use_thumbnail) return get_interns1_target_ratios(min_num, max_num) def get_image_size_with_most_features(self) -> ImageSize: processor = self.get_hf_processor() hf_config = self.ctx.get_hf_config() base_height, base_width = hf_config.vision_config.image_size target_ratios = self.resolve_target_ratios() largest_feature_size, largest_feature_pinpoint = 0, None for wr, hr in target_ratios: width, height = base_width * wr, base_height * hr feat_size = self.get_num_image_tokens( image_width=width, image_height=height, processor=processor.image_processor, ) if feat_size > largest_feature_size: largest_feature_size = feat_size largest_feature_pinpoint = ImageSize(width=width, height=height) assert not (largest_feature_size == 0 or largest_feature_pinpoint is None), ("Cannot have a largest feature size of 0!") return largest_feature_pinpoint def get_max_image_tokens(self) -> int: processor = self.get_hf_processor() target_width, target_height = self.get_image_size_with_most_features() return self.get_num_image_tokens( image_width=target_width, image_height=target_height, processor=processor.image_processor, ) class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo] ): """Basic image-only DummyInputsBuilder for InternS1-style models.""" def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) image_token = self.info.get_hf_processor().image_token return image_token * num_images def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], ) -> MultiModalDataDict: target_width, target_height = \ self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } class InternS1MultiModalProcessor( BaseMultiModalProcessor[InternS1ProcessingInfo]): """ Basic image-only MultiModalProcessor for InternS1-style models.""" def _call_hf_processor( self, prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, tok_kwargs=tok_kwargs, ) hf_processor = self.info.get_hf_processor(**mm_kwargs) image_token_id = hf_processor.image_token_id # Since there may be extra tokens in the feature placeholders, # we need to pass the image token ID to the model to select the # tokens to merge from the vision encoder outputs processed_outputs["image_token_id"] = torch.tensor(image_token_id) images = mm_data.get('images', None) image_processor = self.info.get_hf_processor().image_processor if images is not None: image_inputs = image_processor(images=images) image_num_patches = image_inputs.pop("num_patches") if not isinstance(image_num_patches, list): raise ValueError( f'num_patches is supposed to be list, but got ' f'{type(image_num_patches)}') image_num_patches = torch.tensor(image_num_patches) processed_outputs['image_num_patches'] = image_num_patches return processed_outputs def _get_mm_fields_config( self, hf_inputs: Mapping[str, NestedTensors], hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) num_images = len(image_num_patches) return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes( "image", image_num_patches), image_num_patches=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), image_token_id=MultiModalFieldConfig.shared("image", num_images), ) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) img_context_token = hf_processor.image_token start_image_token = hf_processor.start_image_token end_image_token = hf_processor.end_image_token def get_replacement(item_idx: int): images = mm_items.get_items( "image", (ImageEmbeddingItems, ImageProcessorItems)) if isinstance(images, ImageEmbeddingItems): feature_size = images.get_feature_size(item_idx) else: image_size = images.get_image_size(item_idx) feature_size = self.info.get_num_image_tokens( image_width=image_size.width, image_height=image_size.height, processor=hf_processor.image_processor, ) repl_features = img_context_token * feature_size repl_full = start_image_token + repl_features + end_image_token return PromptUpdateDetails.select_text(repl_full, img_context_token) return [ PromptReplacement( modality="image", target=img_context_token, replacement=get_replacement, ) ] @MULTIMODAL_REGISTRY.register_processor( InternS1MultiModalProcessor, info=InternS1ProcessingInfo, dummy_inputs=InternS1DummyInputsBuilder) class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "lm_head.": "language_model.lm_head.", "model.language_model.": "language_model.model.", "model.vision_tower.": "vision_tower.", "model.multi_modal_projector.": "multi_modal_projector.", }) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: # transformers InternVLProcessor uses as the seperator # refer to https://github.com/huggingface/transformers/blob/f90de364c2484c7c325bbe05befdcf487bd75b63/src/transformers/models/internvl/processing_internvl.py#L116 if modality.startswith("image"): return '' if modality.startswith("video"): return "