# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from math import sqrt
from typing import Annotated, Any, Literal, TypeAlias

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BatchFeature

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.step3_vl import Step3VisionEncoderConfig
from vllm.transformers_utils.processors.step3_vl import (
    MAX_IMAGE_SIZE,
    Step3VLImageProcessor,
    Step3VLProcessor,
)
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model


class Step3VLImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width
        - bnp: Batch size * number of images * number of patches
        - hp: Height of patch
        - wp: Width of patch
    """

    # Defaulted for consistency with Step3VLImageEmbeddingInputs.
    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
    patch_pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "hp", "wp")]
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]


class Step3VLImageEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size
        - h: Hidden size (must match the hidden size of language model
          backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]


Step3VLImageInputs: TypeAlias = Step3VLImagePixelInputs | Step3VLImageEmbeddingInputs


class Step3VLProcessingInfo(BaseProcessingInfo):
    """Processor and size information used by the Step3-VL multimodal pipeline."""

    def get_image_processor(self, **kwargs):
        """Build the image processor; `enable_patch` defaults to the HF config."""
        config = self.get_hf_config()
        kwargs.setdefault(
            "enable_patch",
            getattr(config.vision_config, "enable_patch", True),
        )
        return Step3VLImageProcessor(**kwargs)

    def get_hf_processor(self, **kwargs: object) -> Step3VLProcessor:
        """Build the Step3-VL processor from the tokenizer and image processor.

        `_get_prompt_updates` calls this with `**hf_processor_mm_kwargs`, so
        keyword arguments must be accepted; they are forwarded to
        `get_image_processor`.
        """
        # NOTE(review): forwarding to the image processor assumes the mm
        # kwargs are image-processor options - confirm against callers.
        return Step3VLProcessor(
            tokenizer=self.get_tokenizer(),
            image_processor=self.get_image_processor(**kwargs),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None}

    def get_max_image_tokens(self) -> int:
        """Return the token count for the image size with the most features."""
        image_processor = self.get_image_processor()
        target_width, target_height = self.get_image_size_with_most_features()
        return image_processor.get_num_image_tokens(target_width, target_height)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"image": self.get_max_image_tokens()}

    def get_image_size_with_most_features(self) -> ImageSize:
        return ImageSize(MAX_IMAGE_SIZE, MAX_IMAGE_SIZE)


class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
    """Builds dummy text and images for the requested number of images."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder per requested dummy image."""
        num_images = mm_counts.get("image", 0)
        # The literal here was empty, which produced no placeholders at all.
        # NOTE(review): "<im_patch>" is taken to be the Step3 image
        # placeholder token - verify against Step3VLProcessor.
        return "<im_patch>" * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy images at the size that yields the most features."""
        target_width, target_height = self.info.get_image_size_with_most_features()
        num_images = mm_counts.get("image", 0)
        image_overrides = mm_options.get("image")

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo]):
    """Multimodal processor that expands image placeholders for Step3-VL."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image placeholder token with its full feature ids."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_placeholder_token_id = hf_processor.image_token_id

        def get_replacement_step1o(item_idx: int):
            out_item = out_mm_kwargs["image"][item_idx]
            num_patches = int(out_item["num_patches"].data)
            patch_newline_mask = out_item["patch_newline_mask"].data
            image_repl_ids = hf_processor.get_image_repl_feature_ids(
                1, num_patches, patch_newline_mask.tolist()
            )
            # Only positions holding the placeholder id receive embeddings.
            return PromptUpdateDetails.select_token_id(
                seq=image_repl_ids,
                embed_token_id=image_placeholder_token_id,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[image_placeholder_token_id],
                replacement=get_replacement_step1o,
            )
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output is split across images."""
        num_patches = hf_inputs.get("num_patches", torch.empty(0))

        # Per-patch fields are flat and sliced by each image's patch count.
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            patch_pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", num_patches
            ),
            num_patches=MultiModalFieldConfig.batched("image"),
            patch_newline_mask=MultiModalFieldConfig.flat_from_sizes(
                "image", num_patches
            ),
        )


def get_abs_pos(abs_pos, tgt_size):
    """Resize absolute position embeddings to a target number of patches.

    The first row of `abs_pos` (after dropping its leading dimension) is
    kept as-is as the class-token embedding; the remaining rows are treated
    as a square grid and bicubically interpolated when the grid side differs
    from the target.

    Args:
        abs_pos: Position embeddings whose leading dimension is squeezed
            away; the remaining rows are one class token plus a square grid.
        tgt_size: Target number of patch positions (square of the grid side).

    Returns:
        `abs_pos` unchanged when the grid sides already match, otherwise a
        tensor of shape `(1, side * side + 1, dim)`.
    """
    dim = abs_pos.size(-1)
    abs_pos_new = abs_pos.squeeze(0)
    cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]

    src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
    tgt_size = int(math.sqrt(tgt_size))
    dtype = abs_pos.dtype

    if src_size == tgt_size:
        return abs_pos

    # (src*src, dim) -> (1, dim, src, src) so F.interpolate sees an image.
    old_pos_embed = (
        old_pos_embed.view(1, src_size, src_size, dim)
        .permute(0, 3, 1, 2)
        .contiguous()
    )
    # Interpolate in float32, then cast back to the original dtype.
    old_pos_embed = old_pos_embed.to(torch.float32)
    new_pos_embed = F.interpolate(
        old_pos_embed,
        size=(tgt_size, tgt_size),
        mode="bicubic",
        antialias=True,
        align_corners=False,
    ).to(dtype)
    new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
    new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
    vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
    vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
    return vision_pos_embed


class Step3VisionEmbeddings(nn.Module):
    """Patch, class-token and position embeddings for the vision encoder."""

    def __init__(self, config: Step3VisionEncoderConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))

        self.patch_embedding = Conv2dLayer(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=True,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        # Number of leading class-token positions produced by forward().
        self.pad_tp_size = 4  # hard code for padding
        # To load the pretrained weights, we still use P+1 as the seqlen
        self.position_embedding = torch.nn.Embedding(
            self.num_patches + 1, self.embed_dim
        )
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_patches + 1).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed images into a sequence with `pad_tp_size` leading class tokens."""
        batch_size = pixel_values.shape[0]
        patch_embeds = self.patch_embedding(
            pixel_values
        )  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + get_abs_pos(
            self.position_embedding(self.position_ids), patch_embeds.size(1)
        )
        # pad: prepend (pad_tp_size - 1) extra copies of the class-token
        # position so the sequence starts with pad_tp_size class tokens.
        embeddings = torch.cat(
            [
                embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1, 1),
                embeddings,
            ],
            dim=1,
        )
        return embeddings


class Step3VisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.total_num_heads

        self.scale = self.head_dim**-0.5

        # With data-parallel ViT, the linear layers are not TP-sharded.
        use_data_parallel = is_vit_use_data_parallel()
        tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.q_size = self.num_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )
        self.out_proj = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )

        # Use unified MMEncoderAttention with automatic backend selection
        self.attn = MMEncoderAttention(
            self.num_heads,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel"""
        # Fused projection; split the last dimension into q, k and v.
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)

        # Use unified MMEncoderAttention with automatic backend selection
        attn_output = self.attn(q, k, v)

        attn_output, _ = self.out_proj(attn_output)

        return attn_output


class Step3VisionMLP(nn.Module):
    """Two-layer feed-forward block (fc1 -> activation -> fc2)."""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        use_data_parallel = is_vit_use_data_parallel()
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class Step3VisionEncoderLayer(nn.Module):
    """One encoder layer: attention and MLP, each normalized then added back."""

    def __init__(
        self,
        config: Step3VisionEncoderConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Step3VisionAttention(
            config,
            quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Step3VisionMLP(
            config,
            quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # The layer norm is applied to each sub-layer's OUTPUT before the
        # residual add (not to its input).
        hidden_states = hidden_states + self.layer_norm1(self.self_attn(hidden_states))
        hidden_states = hidden_states + self.layer_norm2(self.mlp(hidden_states))
        return hidden_states


class Step3VisionEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` encoder layers applied in order."""

    def __init__(
        self,
        config: Step3VisionEncoderConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Step3VisionEncoderLayer(
                    config,
                    quant_config,
                    prefix=f"{prefix}.layers.{i}",
                )
                for i in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds,
    ):
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)
        return hidden_states


class Step3VisionTransformer(nn.Module):
    """Vision tower: embeddings followed by the transformer encoder."""

    def __init__(
        self,
        config: Step3VisionEncoderConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.use_data_parallel = is_vit_use_data_parallel()
        self.image_size = config.image_size
        self.embeddings = Step3VisionEmbeddings(config)
        self.transformer = Step3VisionEncoder(
            config,
            quant_config,
            prefix=f"{prefix}.transformer",
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
    ):
        hidden_states = self.embeddings(pixel_values)
        if self.use_data_parallel:
            hidden_states = run_dp_sharded_vision_model(hidden_states, self.transformer)
        else:
            hidden_states = self.transformer(inputs_embeds=hidden_states)

        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(
    Step3VLMultiModalProcessor,
    info=Step3VLProcessingInfo,
    dummy_inputs=Step3VLDummyInputsBuilder,
)
class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Step3-VL: vision tower + projector feeding a registered language model."""

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image item.

        Raises:
            ValueError: If `modality` is not an image modality.
        """
        if modality.startswith("image"):
            # NOTE(review): "<im_patch>" is taken to be the Step3 image
            # placeholder token - verify against Step3VLProcessor.
            return "<im_patch>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        # NOTE: This behavior is consistent with the previous OOV handling,
        # but does not currently handle the start/stop tokens that wrap the
        # image features.
        # See: https://huggingface.co/stepfun-ai/step3/blob/main/processing_step3v.py#L323
        #
        # If this becomes an issue or we refactor to handle this using the
        # processor info in the future, it would probably be best to handle
        # those too.
        self.configure_mm_token_handling(
            self.config.text_config.vocab_size,
            [self.config.image_token_id],
        )

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_model = Step3VisionTransformer(
                config.vision_config,
                None,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            self.vit_downsampler = Conv2dLayer(
                config.vision_config.hidden_size,
                config.vision_config.output_hidden_size,
                kernel_size=2,
                stride=config.understand_projector_stride,
            )
            self.vit_downsampler2 = Conv2dLayer(
                config.vision_config.output_hidden_size,
                config.vision_config.output_hidden_size * 2,
                kernel_size=3,
                stride=2,
                padding=1,
            )
            self.vit_large_projector = nn.Linear(
                config.vision_config.output_hidden_size * 2,
                config.hidden_size,
                bias=config.projector_bias,
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Step3VLImageInputs | None:
        """Build a validated image-input schema from the multimodal kwargs.

        Returns:
            `None` when neither pixel values nor precomputed embeddings are
            given, otherwise the matching schema instance.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        patch_pixel_values = kwargs.pop("patch_pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None and patch_pixel_values is not None:
            return Step3VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values.to(self.dtype),
                patch_pixel_values=patch_pixel_values.to(self.dtype),
                num_patches=num_patches,
            )

        if image_embeds is not None:
            # The schema's tensor field is named `data`.
            return Step3VLImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds.to(self.dtype),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
        """Downsample vision features and project them to the LM hidden size."""
        B, P = image_features.shape[:2]
        HW = int(sqrt(P))
        # (B, P, C) -> (B, C, HW, HW) so the conv downsamplers can be applied.
        image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
        image_features = self.vit_downsampler(image_features)
        image_features = self.vit_downsampler2(image_features)
        n_dim = image_features.size(1)
        # Back to (B, tokens, channels) for the linear projector.
        image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
        image_features = self.vit_large_projector(image_features)
        return image_features

    def _get_vision_model_output(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Run the vision tower and drop its leading class-token positions."""
        # Step3VisionEmbeddings.forward emits `pad_tp_size` class tokens at
        # the start of the sequence; keep only the patch tokens.
        num_prefix_tokens = self.vision_model.embeddings.pad_tp_size
        return self.vision_model(input_tensor)[:, num_prefix_tokens:]

    def _process_image_input(
        self, image_input: Step3VLImageInputs
    ) -> list[torch.Tensor]:
        """Compute one embedding tensor per image.

        For pixel inputs, each image's patch features (if any) are placed
        before its global image features, all flattened to
        `(tokens, hidden)`. Precomputed embeddings are returned per image
        without further processing.
        """
        if image_input["type"] == "image_embeds":
            # Per the schema, these already have the language-model hidden
            # size, so they must not go through the vision projector again.
            return list(image_input["data"])

        image_features = self._get_vision_model_output(image_input["pixel_values"])
        patch_image_features = (
            self._get_vision_model_output(image_input["patch_pixel_values"])
            if len(image_input["patch_pixel_values"]) > 0
            else None
        )
        num_patches = image_input["num_patches"]

        image_features = self._process_image_features(image_features)
        patch_image_features = (
            self._process_image_features(patch_image_features)
            if patch_image_features is not None
            else None
        )

        # Convert once up front instead of comparing tensor elements (and
        # slicing with them) on every loop iteration.
        if isinstance(num_patches, torch.Tensor):
            patch_counts = num_patches.tolist()
        else:
            patch_counts = [int(count) for count in num_patches]

        merged_image_features = []
        cur_patch_idx = 0
        for i, num_patch in enumerate(patch_counts):
            cur_feature = []
            if num_patch > 0:
                patch_slice = patch_image_features[
                    cur_patch_idx : cur_patch_idx + num_patch
                ]
                cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
            cur_feature.append(image_features[i].view(-1, image_features.shape[-1]))
            cur_patch_idx += num_patch
            merged_image_features.append(
                torch.cat(cur_feature) if len(cur_feature) > 1 else cur_feature[0]
            )
        return merged_image_features

    def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
        """Return per-image embeddings, or an empty list without image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model; embeddings are dropped when intermediate
        tensors are supplied."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load weights, remapping HF prefixes via `hf_to_vllm_mapper`."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)