# SPDX-License-Identifier: Apache-2.0 import math from collections.abc import Iterable, Mapping, Sequence from typing import Any, Literal, Optional, TypedDict import torch from torch import nn from transformers import BatchFeature, Gemma3Config, Gemma3Processor from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs import vllm.envs as envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargs) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) # yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, BoundPromptUpdate, PlaceholderFeaturesInfo, PromptReplacement, PromptTargetMatch, PromptUpdate, PromptUpdateDetails, find_mm_placeholders, replace_token_matches) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) logger = init_logger(__name__) class Gemma3ImagePixelInputs(TypedDict): type: Literal["pixel_values"] pixel_values: torch.Tensor """ Shape: `(num_patches_total, num_channels, height, width)` `num_patches_total` is the total number of patches over each image over each prompt in the batch. """ num_patches: torch.Tensor """Shape: `(batch_size * num_images)`""" Gemma3ImageInputs = Gemma3ImagePixelInputs class Gemma3ProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config(Gemma3Config) def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor(Gemma3Processor, **kwargs) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} def _resolve_image_kwargs( self, processor: Gemma3Processor, keys: set[str], ) -> dict[str, Any]: image_processor = processor.image_processor kwargs = processor._merge_kwargs( Gemma3ProcessorKwargs, tokenizer_init_kwargs=processor.tokenizer.init_kwargs, ) images_kwargs = kwargs["images_kwargs"] def _resolve_kw(key: str): val = getattr(image_processor, key) if val is None: val = images_kwargs[key] return val return {k: _resolve_kw(k) for k in keys} def get_num_crops( self, *, image_width: int, image_height: int, processor: Optional[Gemma3Processor], ) -> int: if processor is None: processor = self.get_hf_processor() images_kwargs = self._resolve_image_kwargs( processor, { "do_pan_and_scan", "pan_and_scan_min_crop_size", "pan_and_scan_max_num_crops", "pan_and_scan_min_ratio_to_activate" }) do_pan_and_scan = images_kwargs["do_pan_and_scan"] pan_and_scan_min_crop_size = images_kwargs[ "pan_and_scan_min_crop_size"] pan_and_scan_max_num_crops = images_kwargs[ "pan_and_scan_max_num_crops"] pan_and_scan_min_ratio_to_activate = images_kwargs[ "pan_and_scan_min_ratio_to_activate"] if not do_pan_and_scan: return 0 if envs.VLLM_USE_V1: logger.warning_once( "`do_pan_and_scan=True` has suboptimal results on V1 " "because of the simplified attention pattern being used.") # Based on Gemma3ImageProcessor.pan_and_scan if image_width >= image_height: if image_width / image_height < pan_and_scan_min_ratio_to_activate: return 0 num_crops_w = min( int(math.floor(image_width / pan_and_scan_min_crop_size)), int(math.floor(image_width / image_height + 0.5)), ) num_crops_w = max(2, num_crops_w) num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w) num_crops_h = 1 else: if image_height / image_width < pan_and_scan_min_ratio_to_activate: return 0 num_crops_h = min( int(math.floor(image_height / pan_and_scan_min_crop_size)), int(math.floor(image_height / image_width + 0.5)), ) num_crops_h = max(2, num_crops_h) num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h) num_crops_w = 1 crop_size_w = int(math.ceil(image_width / num_crops_w)) crop_size_h = int(math.ceil(image_height / num_crops_h)) if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size: return 0 return num_crops_w * num_crops_h def get_image_repl( self, *, image_width: int, image_height: int, processor: Optional[Gemma3Processor], ) -> PromptUpdateDetails[str]: if processor is None: processor = self.get_hf_processor() boi_token = processor.boi_token num_crops = self.get_num_crops( image_width=image_width, image_height=image_height, processor=processor, ) if num_crops == 0: image_text = boi_token else: crops_image_tokens = " ".join(boi_token for _ in range(num_crops)) image_text = ( f"Here is the original image {boi_token} and here are some " f"crops to help you see better {crops_image_tokens}") repl_full = image_text.replace(boi_token, processor.full_image_sequence) tokenizer = processor.tokenizer vocab = tokenizer.get_vocab() image_token_id = vocab[tokenizer.image_token] return PromptUpdateDetails.select_token_id(repl_full, image_token_id) def get_num_image_tokens( self, *, image_width: int, image_height: int, processor: Optional[Gemma3Processor], ) -> int: if processor is None: processor = self.get_hf_processor() num_crops = self.get_num_crops( image_width=image_width, image_height=image_height, processor=processor, ) image_seq_len = processor.image_seq_length return (num_crops + 1) * image_seq_len def get_image_size_with_most_features(self) -> ImageSize: processor = self.get_hf_processor() images_kwargs = self._resolve_image_kwargs( processor, {"pan_and_scan_max_num_crops"}) max_num_crops = images_kwargs["pan_and_scan_max_num_crops"] # Result in the max possible feature size (h:w = max_num_crops:1) return ImageSize(height=50 * max_num_crops, width=50) class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) processor = self.info.get_hf_processor() image_token = processor.boi_token return image_token * num_images def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) target_width, target_height = \ self.info.get_image_size_with_most_features() return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): def _call_hf_processor( self, prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt, mm_data, mm_kwargs, ) # HF processor pops the `num_crops` kwarg, which is needed by vLLM if (images := mm_data.get("images")) is not None: parsed_images = (self._get_data_parser().parse_mm_data({ "image": images }).get_items("image", ImageProcessorItems)) image_sizes = [ parsed_images.get_image_size(i) for i in range(len(parsed_images)) ] hf_processor = self.info.get_hf_processor(**mm_kwargs) num_crops = [ self.info.get_num_crops(image_width=size.width, image_height=size.height, processor=hf_processor) for size in image_sizes ] processed_outputs["num_crops"] = torch.tensor(num_crops) return processed_outputs def _get_mm_fields_config( self, hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: num_crops = hf_inputs.get("num_crops", torch.empty(0)) return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes( "image", num_crops + 1), num_crops=MultiModalFieldConfig.batched("image"), ) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token = hf_processor.boi_token def get_replacement_gemma3(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) image_size = images.get_image_size(item_idx) return self.info.get_image_repl( image_width=image_size.width, image_height=image_size.height, processor=hf_processor, ) return [ PromptReplacement( modality="image", target=image_token, replacement=get_replacement_gemma3, ) ] def _apply_token_matches( self, prompt: list[int], mm_matches: Mapping[str, Sequence[PromptTargetMatch]], mm_item_counts: Mapping[str, int], ) -> list[int]: token_ids = super()._apply_token_matches( prompt, mm_matches, mm_item_counts, ) # "\n\n\n" and "\n\n\n\n" are single tokens # Since our replacement can insert "\n\n" next to "\n" # tokens, we have to combine them to be consistent with # the output of the tokenizer tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() newline_1 = vocab["\n"] newline_2 = vocab["\n\n"] newline_3 = vocab["\n\n\n"] newline_4 = vocab["\n\n\n\n"] token_ids = replace_token_matches( token_ids, [newline_1, newline_2], [newline_3], ) token_ids = replace_token_matches( token_ids, [newline_2, newline_1], [newline_3], ) token_ids = replace_token_matches( token_ids, [newline_2, newline_2], [newline_4], ) return token_ids def _find_mm_placeholders( self, mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], mm_item_counts: Mapping[str, int], ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() newline_1 = vocab["\n"] newline_2 = vocab["\n\n"] newline_3 = vocab["\n\n\n"] newline_4 = vocab["\n\n\n\n"] def get_repl_toks(tok: int) -> list[int]: if tok == newline_3: return [newline_1, newline_2] if tok == newline_4: return [newline_2, newline_2] return [tok] repl_token_ids = list[int]() repl_orig_idxs = list[int]() for orig_idx, orig_tok in enumerate(new_token_ids): repl_toks = get_repl_toks(orig_tok) repl_token_ids.extend(repl_toks) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, mm_item_counts) return { modality: [ PlaceholderFeaturesInfo( modality=p.modality, item_idx=p.item_idx, start_idx=repl_orig_idxs[p.start_idx], tokens=p.tokens, is_embed=p.is_embed, ) for p in placeholders ] for modality, placeholders in repls.items() } class Gemma3MultiModalProjector(nn.Module): def __init__(self, config: Gemma3Config): super().__init__() self.mm_input_projection_weight = nn.Parameter( torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)) self.mm_soft_emb_norm = GemmaRMSNorm( config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps) self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size) self.tokens_per_side = int(config.mm_tokens_per_image**0.5) self.kernel_size = self.patches_per_image // self.tokens_per_side self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) def forward(self, vision_outputs: torch.Tensor): batch_size, _, seq_length = vision_outputs.shape reshaped_vision_outputs = vision_outputs.transpose(1, 2) reshaped_vision_outputs = reshaped_vision_outputs.reshape( batch_size, seq_length, self.patches_per_image, self.patches_per_image) reshaped_vision_outputs = reshaped_vision_outputs.contiguous() pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) pooled_vision_outputs = pooled_vision_outputs.flatten(2) pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) projected_vision_outputs = torch.matmul( normed_vision_outputs, self.mm_input_projection_weight) return projected_vision_outputs.type_as(vision_outputs) @MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor, info=Gemma3ProcessingInfo, dummy_inputs=Gemma3DummyInputsBuilder) class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", ], "gate_up_proj": [ "gate_proj", "up_proj", ], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config self.sliding_window = getattr(config.text_config, "interleaved_sliding_window", None) self.vision_tower = SiglipVisionModel(config.vision_config, quant_config, prefix=maybe_prefix( prefix, "vision_tower")) self.multi_modal_projector = Gemma3MultiModalProjector(config) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), architectures=["Gemma3ForCausalLM"], ) logit_scale = getattr(config, "logit_scale", 1.0) self.language_model.logits_processor.scale *= logit_scale self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @property def dtype(self): return next(self.parameters()).dtype def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: image_size = self.config.vision_config.image_size expected_dims = (3, image_size, image_size) if data.shape[1:] != expected_dims: raise ValueError( "The expected shape of pixel values per image per batch is " f"{expected_dims}. You supplied {tuple(data.shape)}.") return data def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[Gemma3ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) num_crops = kwargs.pop("num_crops", None) image_embeds = kwargs.pop("image_embeds", None) assert image_embeds is None, "Gemma3 does not support image_embeds." if pixel_values is None: return None if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") if not isinstance(num_crops, (torch.Tensor, list)): raise ValueError("Incorrect type of num_crops. " f"Got type: {type(num_crops)}") pixel_values = flatten_bn(pixel_values, concat=True) num_crops = flatten_bn(num_crops, concat=True) return Gemma3ImagePixelInputs( type="pixel_values", pixel_values=self._validate_pixel_values(pixel_values), num_patches=num_crops + 1, ) def _image_pixels_to_features( self, vision_tower: SiglipVisionModel, pixel_values: torch.Tensor, ) -> torch.Tensor: return vision_tower(pixel_values) def _process_image_input( self, image_input: Gemma3ImageInputs, ) -> list[torch.Tensor]: assert self.vision_tower is not None pixel_values = image_input["pixel_values"] num_patches = image_input["num_patches"] image_features = self._image_pixels_to_features( self.vision_tower, pixel_values, ) image_embeds = self.multi_modal_projector(image_features) return [ e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist()) ] def get_language_model(self) -> torch.nn.Module: return self.language_model def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return None return self._process_image_input(image_input) def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, multimodal_embeddings, self.config.image_token_index, ) return inputs_embeds def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object) -> IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) if vision_embeddings is not None: kwargs = self.prepare_attn_masks( input_ids, positions, mask_dtype=self.dtype, **kwargs, ) input_ids = None hidden_states = self.language_model.model(input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds, **kwargs) return hidden_states def prepare_attn_masks( self, input_ids: torch.Tensor, positions: torch.Tensor, mask_dtype: torch.dtype, **kwargs, ): kwargs["has_images"] = True # NOTE(woosuk): Here, we distinguish the sequences by the position id 0. # This is a HACK. Fix this. start_idices = (positions == 0).cpu().nonzero() num_seqs = len(start_idices) seq_lens = [] for i in range(num_seqs): start_idx = start_idices[i].item() if i < num_seqs - 1: end_idx = start_idices[i + 1].item() else: end_idx = len(input_ids) seq_lens.append(end_idx - start_idx) kwargs["seq_lens"] = seq_lens global_attn_masks = [] local_attn_masks = [] start_idx = 0 for seq_len in seq_lens: end_idx = start_idx + seq_len input_token_ids = input_ids[start_idx:end_idx] start_idx = end_idx # Create a global causal mask. global_attn_mask = torch.empty( 1, 1, seq_len, seq_len, dtype=mask_dtype, device=input_ids.device, ) global_attn_mask.fill_(float("-inf")) # Fill the lower triangle with 0. global_attn_mask = global_attn_mask.triu(diagonal=1) # Consider the bidirectional attention between image tokens. img_mask = torch.zeros_like(global_attn_mask) img_pos = (input_token_ids == self.config.image_token_index) img_mask[:, :, :, img_pos] += 1 img_mask[:, :, img_pos, :] += 1 global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) global_attn_masks.append(global_attn_mask) if self.sliding_window is not None: # Create a local causal mask with sliding window (1024). local_attn_mask = torch.ones_like(global_attn_mask) local_attn_mask = torch.tril(local_attn_mask, diagonal=-self.sliding_window) local_attn_mask = torch.where(local_attn_mask == 0, global_attn_mask, float("-inf")) local_attn_masks.append(local_attn_mask) kwargs["global_attn_masks"] = global_attn_masks kwargs["local_attn_masks"] = local_attn_masks return kwargs def compute_logits( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states, sampling_metadata) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) def get_mm_mapping(self) -> MultiModelKeys: """ Get the module prefix in multimodal models """ return MultiModelKeys.from_string_field( language_model="language_model", connector="multi_modal_projector", tower_model="vision_tower")