# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
from collections.abc import Iterable, Mapping
from contextlib import contextmanager
from pathlib import Path
from typing import Literal, Optional, Union

import regex as re
import torch
import transformers
from packaging.version import Version
from torch import nn
from transformers import (AutoModel, BatchFeature, PretrainedConfig,
                          PreTrainedModel)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, VllmConfig)
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputs, MultiModalUUIDDict,
                                    PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP, SupportsQuant)
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
                    flatten_bn, make_empty_intermediate_tensors_factory,
                    maybe_prefix)

logger = init_logger(__name__)


def get_feature_request_tip(
    model: str,
    trust_remote_code: bool,
) -> str:
    """Build a hint telling the user how to get an unsupported feature.

    Args:
        model: Model name or local path, as configured in vLLM.
        trust_remote_code: If `True`, the hint points at the model's
            Hugging Face Hub discussions; otherwise at the Transformers
            issue tracker.

    Returns:
        A feature-request sentence (omitted when `model` is an existing
        local path) followed by a link to the custom-model docs.
    """
    hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new"
    gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose"  # noqa: E501
    url = hf_url if trust_remote_code else gh_url
    prefix = f"Please open {url} to request support for this feature. "
    if Path(model).exists():
        # A local path has no Hub page, so drop the feature-request link.
        prefix = ""
    doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"  # noqa: E501
    tip = f"See {doc_url} for instructions on how to add support yourself."
    return f"{prefix}{tip}"


def vllm_flash_attention_forward(
        # Transformers args
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor,
        # Transformers kwargs
        scaling: Optional[float] = None,
        # vLLM kwargs
        attention_instances: Optional[dict[int, Attention]] = None,
        **kwargs):
    """Run a Transformers attention call through vLLM's `Attention`.

    Registered in `ALL_ATTENTION_FUNCTIONS` under the name `"vllm"`, so
    Transformers models whose `_attn_implementation` is `"vllm"` dispatch
    here.

    Args:
        module: The Transformers attention module. Its `layer_idx` selects
            the vLLM `Attention` instance to use.
        query, key, value: Attention inputs. Dims 1 and 2 are swapped, then
            each is reshaped to `(query.shape[-2], -1)`. This presumes a
            leading batch dim of 1, as added in `TransformersBase.forward`.
        attention_mask: Unused here.
        scaling: Softmax scale chosen by Transformers. When given, it
            overrides the default scale set on the vLLM attention impl.
        attention_instances: Mapping from layer index to vLLM `Attention`.

    Returns:
        Tuple of the attention output and `None` (no attention weights).
    """
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        self_attn.impl.scale = float(scaling)
    num_tokens = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(num_tokens, -1)
                         for x in (query, key, value))
    return self_attn.forward(query, key, value), None


ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Log (at debug level) that module `name` was swapped out."""
    logger.debug("%s: %s -> %s", name, old_module, new_module)


def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
    """
    Callable to be passed to `@support_torch_compile`'s `enable_if` argument.

    Defaults to `True` but is disabled in the following situations:

    - The model uses dynamic rope scaling.
    """
    text_config = vllm_config.model_config.hf_config.get_text_config()
    # Dynamic rope scaling is not compatible with torch.compile
    rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
    return rope_scaling.get("rope_type") != "dynamic"


# Tensor parallel styles understood by `replace_linear_class`.
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]


def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: Optional[QuantizationConfig] = None,
    *,
    prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    Args:
        linear: `nn.Linear` to be replaced.
        style: Tensor parallel style of the new linear, e.g. "colwise".
            Any string not in the mapping below falls back to
            `ReplicatedLinear`.
        quant_config: Quantization config for the new linear.
        prefix: Qualified name of the layer, forwarded to the new linear.

    Returns:
        The new linear. It is created with `return_bias=False`.

    Raises:
        ValueError: If `style` is not a string.
    """

    if not isinstance(style, str):
        raise ValueError(
            f"Unsupported parallel style type {type(style)}, expected str")

    vllm_linear_cls, vllm_linear_kwargs = {
        "colwise": (ColumnParallelLinear, {}),
        "colwise_rep": (ColumnParallelLinear, {
            "gather_output": True
        }),
        "rowwise": (RowParallelLinear, {}),
        "rowwise_rep": (RowParallelLinear, {
            "input_is_parallel": False
        }),
        "replicate": (ReplicatedLinear, {}),
    }.get(style, (ReplicatedLinear, {}))

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,
        **vllm_linear_kwargs,
    )


def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
    """Replace a Transformers RMSNorm with vLLM's RMSNorm.

    This method assumes:
    - Weight is stored as `weight`.
    - Epsilon is stored as `eps` or `variance_epsilon`.
    - `with_scale` indicates whether the layer has a weight (Gemma3n only).
    - `var_hidden_size` is only ever used for Intern vision encoder in vLLM
    and Transformers doesn't appear to have the same concept.
    """
    kwargs = {
        "hidden_size": hidden_size,
        "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6),
        "has_weight": getattr(rms_norm, "with_scale", True)
    }
    if (weight := getattr(rms_norm, "weight", None)) is not None:
        # If weight is a Parameter, get its data tensor
        weight = getattr(weight, "data", weight)
        kwargs["dtype"] = weight.dtype
    else:
        # No weight, fall back to weightless RMSNorm
        kwargs["has_weight"] = False
    return RMSNorm(**kwargs)


# Copied from `accelerate`
@contextmanager
def init_on_device_without_buffers(device: torch.device):
    """
    A context manager under which models are initialized with all
    parameters on the specified device. However buffers are not
    initialized on specified device.

    Only `nn.Module.register_parameter` is patched; it is restored on exit,
    even if the body raises.

    Args:
        device (`torch.device`):
            Device to initialize all parameters on.
    """

    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            # Re-create the parameter on `device`, keeping its class and
            # attributes (same logic as the `accelerate` original).
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(
                module._parameters[name].to(device), **kwargs)

    # Empty: no torch tensor constructors are patched, so buffers are left
    # on whatever device they are created on.
    tensor_constructors_to_patch = {}

    def patch_tensor_constructor(fn):

        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    try:
        nn.Module.register_parameter = register_empty_parameter
        for torch_function_name in tensor_constructors_to_patch:
            setattr(
                torch, torch_function_name,
                patch_tensor_constructor(getattr(torch, torch_function_name)))
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
        for torch_function_name, old_torch_function in (
                tensor_constructors_to_patch.items()):
            setattr(torch, torch_function_name, old_torch_function)


class MultiModalProcessingInfo(BaseProcessingInfo):
    """Processing info for image inputs of Transformers backend models."""

    def get_supported_mm_limits(self):
        # Only images are supported; `None` presumably means "no fixed
        # per-prompt limit" -- see `BaseProcessingInfo`.
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
        return {"image": self.get_max_image_tokens()}

    def get_max_image_tokens(self) -> int:
        """Image-token count the HF processor reports for the largest image.

        The largest image size comes from `get_max_image_size`, and the
        model's configured `mm_processor_kwargs` are forwarded to the
        processor.
        """
        width, height = self.get_max_image_size()
        processor = self.get_hf_processor()
        multimodal_config = self.ctx.model_config.multimodal_config
        mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
        mm_tokens = processor._get_num_multimodal_tokens(
            image_sizes=([height, width], ), **mm_processor_kwargs)
        image_tokens = mm_tokens["num_image_tokens"][0]
        return image_tokens

    def get_max_image_size(self):
        return 10_000, 10_000  # hardcode for arbitrary very large size


class MultiModalDummyInputsBuilder(
        BaseDummyInputsBuilder[MultiModalProcessingInfo]):
    """Builds dummy prompts and images for the Transformers backend."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        if "gemma3" in processor.__class__.__name__.lower():
            # Gemma3 processors are given their begin-of-image token
            image_token = processor.boi_token
        else:
            image_token = getattr(processor, "image_token", "")
        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return `mm_counts["image"]` dummy images of the max image size."""
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_max_image_size()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images,
                                   overrides=image_overrides),
        }


class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
    """Multi-modal processor that delegates to the HF processor.

    Unlike most vLLM processors, `apply` runs the HF processor on the text
    and images together and derives placeholder ranges from the
    multimodal token type ids it returns.
    """

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ):
        """
        Return no prompt updates.

        This processor overrides `apply`, which computes the
        `vllm.multimodal.inputs.PlaceholderRange` for each image directly
        from the HF processor output rather than from prompt updates.
        """
        return None

    def _get_mm_fields_config(
        self,
        hf_inputs,
        hf_processor_mm_kwargs,
        num_image_patches: Optional[torch.Tensor] = None,
    ):
        """Describe how each processed HF output is split per image.

        Every remaining key of `hf_inputs`, plus `image_embeds`, is
        flattened and split per image using `num_image_patches`;
        `num_image_patches` itself is batched per image.

        Note: `attention_mask` is popped from `hf_inputs` in place.
        """
        # HF Processors always return a mask but vLLM doesn't need it
        hf_inputs.pop("attention_mask", None)
        mm_fields = {
            key: MultiModalFieldConfig.flat_from_sizes("image",
                                                       num_image_patches)
            for key in hf_inputs
        }
        mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
            "image", num_image_patches)
        mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
        return mm_fields

    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> tuple[list[int], BatchFeature, torch.Tensor]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        Returns:
            A tuple of:
            - the prompt token ids (the processor output has batch size 1);
            - the remaining processed data, without `input_ids` and without
              the token type ids;
            - the multimodal token type ids. `mm_token_type_ids` is used
              when present, otherwise `token_type_ids` (Gemma3).
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
        processor_data["return_mm_token_type_ids"] = True
        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
            tok_kwargs=tokenization_kwargs,
        )
        processed_data.update(passthrough_data)

        prompt_ids, = processed_data.pop("input_ids").tolist()
        if "mm_token_type_ids" in processed_data:
            mm_token_type_ids = processed_data.pop("mm_token_type_ids")
        else:
            # for gemma3 only
            mm_token_type_ids = processed_data.pop("token_type_ids")

        return prompt_ids, processed_data, mm_token_type_ids

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        Apply HF Processor on prompt text and multi-modal data together,
        outputting token IDs and processed tensors.

        Token-id prompts are first decoded back to text, because the HF
        processor is called with text.
        """
        if tokenization_kwargs is None:
            tokenization_kwargs = {}

        mm_items = self._to_mm_items(mm_data)
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        if not isinstance(prompt, str):
            # the prompt is the tokenized ids which is not supported
            # by the hf_processor, which is why we would need to decode the ids
            # into string
            prompt = hf_processor.decode(prompt)

        (prompt_ids, processed_data,
         mm_token_type_ids) = self._apply_hf_processor_text_mm(
             prompt_text=prompt,
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
         )

        # Placeholder positions are the indices where the multimodal token
        # type ids equal 1 (tested on Llava). Prompts and
        # `mm_token_type_ids` always have batch size 1, so index [1] of
        # `torch.where` holds the sequence positions.
        mm_positions = torch.where(mm_token_type_ids == 1)[1]

        images = mm_items.get_items("image", ImageProcessorItems)
        multimodal_config = self.info.ctx.model_config.multimodal_config
        mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
        image_sizes = []
        for item_idx in range(len(images)):
            image_size = images.get_image_size(item_idx)
            image_sizes.append((image_size.height, image_size.width))

        mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
            image_sizes=image_sizes, **mm_processor_kwargs)

        mm_placeholders = {}
        split_sizes = mm_tokens_per_modality["num_image_tokens"]
        if split_sizes:
            # Split the multimodal positions and their token ids per image.
            chunked_mm_positions = torch.split(mm_positions, split_sizes)
            mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
            chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
            ranges = [
                PlaceholderRange(
                    offset=positions[0].item(),
                    length=positions.shape[0],
                    # Only the image tokens themselves receive embeddings;
                    # other tokens inside the range are left as text.
                    is_embed=(item_tokens ==
                              hf_processor.image_token_id).bool())
                for positions, item_tokens in zip(chunked_mm_positions,
                                                  chunked_mm_tokens)
            ]
            mm_placeholders = {"image": ranges}

        num_image_patches = torch.tensor(
            mm_tokens_per_modality["num_image_patches"]
        ) if "num_image_patches" in mm_tokens_per_modality else None
        processed_data['num_image_patches'] = num_image_patches
        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
                                       num_image_patches),
        )

        # Use overrides if provided; fallback to data-dependent hashing.
        mm_hashes = self._hash_mm_items(mm_items,
                                        hf_processor_mm_kwargs,
                                        tokenization_kwargs,
                                        mm_uuids=mm_uuids)

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )


class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
    """Run a Transformers `AutoModel` with vLLM layers swapped in.

    The model is built on the "meta" device. Layers not owned by this
    pipeline-parallel rank are dropped, and linear layers are replaced with
    vLLM's parallel linears. Attention is routed to vLLM through the
    `"vllm"` attention implementation.
    """
    embedding_padding_modules = ["lm_head"]
    embedding_modules = ["embed_tokens"
                         ]  # TODO transformers will have a util to get it

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        logger.info("Using Transformers backend.")

        self.config: PretrainedConfig = vllm_config.model_config.hf_config
        self.text_config: PretrainedConfig = self.config.get_text_config()
        self.cache_config: CacheConfig = vllm_config.cache_config
        self.device_config: DeviceConfig = vllm_config.device_config
        self.model_config: ModelConfig = vllm_config.model_config
        self.parallel_config: ParallelConfig = vllm_config.parallel_config
        self.quant_config: Optional[
            QuantizationConfig] = vllm_config.quant_config

        self.pp_group = get_pp_group()
        self.pp_size = self.pp_group.world_size
        self.pp_rank = self.pp_group.rank_in_group
        self.tp_size = get_tensor_model_parallel_world_size()

        # Weights to skip in `self.load_weights`
        self.skip_prefixes: list[str] = []
        """Skip loading weights whose qualname starts with these prefixes."""
        self.skip_substrs: list[str] = []
        """Skip loading weights whose qualname contains these substrings."""
        self.ignore_unexpected_prefixes: list[str] = []
        """Ignore unexpected weights whose qualname starts with these prefixes.
        """
        self.ignore_unexpected_suffixes: list[str] = []
        """Ignore unexpected weights whose qualname ends with these suffixes."""

        if self.quant_config:
            quant_method_name = self.quant_config.get_name()
            # Check for unsupported quantization methods.
            if quant_method_name == "mxfp4":
                raise NotImplementedError("Transformers backend does not "
                                          "support MXFP4 quantization yet.")
            # Skip loading extra bias for GPTQ models.
            if "gptq" in quant_method_name:
                self.ignore_unexpected_suffixes.append(".bias")

        # Set correct attn and init on "meta" to delay allocating GPU tensors
        # TODO: @raushan, use the public `model.set_attn_implementation()`
        # method once its checks are fixed in Transformers.
        self.text_config._attn_implementation = "vllm"
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(
                self.config,
                torch_dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        # Remove layers not on this pipeline parallel rank
        self.pipeline_parallel()
        # Substitute remaining layers with vLLM's layers as needed
        self.recursive_replace()
        # Create attention instances for KV cache allocation
        self.attention_instances = self.create_attention_instances()

        # Input embeddings
        if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
            names = ("embedding_size", "hidden_size")
            embedding_dim = getattr_iter(self.text_config, names, None)
            assert embedding_dim is not None
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
                    self.text_config.vocab_size,
                    embedding_dim=embedding_dim,
                    org_num_embeddings=self.text_config.vocab_size,
                    quant_config=self.quant_config,
                ))

        # Initialize any parameters that have not had their modules replaced
        self.init_parameters(self.model)

        # Pipeline parallel intermediate tensors
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states"], self.text_config.hidden_size))

    def pipeline_parallel(self):
        """
        Apply the model's pipeline parallelization plan.

        Modules listed before the model's single `nn.ModuleList` are kept on
        the first rank (and on the last rank when word embeddings are tied).
        Only this rank's slice of the `ModuleList` is kept, and modules
        listed after it are kept on the last rank. Everything else is
        replaced with `PPMissingLayer`.

        Raises:
            ValueError: If pipeline parallelism is requested but the model
                has no PP plan, or its plan does not contain exactly one
                `nn.ModuleList`.
        """
        if self.pp_size <= 1:
            return

        if not self.model.supports_pp_plan:
            tip = get_feature_request_tip(self.model_config.model,
                                          self.model_config.trust_remote_code)
            raise ValueError(
                f"{type(self.model)} does not support pipeline parallel. {tip}"
            )

        module_lists = []
        module_list_idx = None
        pp_plan = list(self.model._pp_plan.keys())
        for i, name in enumerate(pp_plan):
            if isinstance(getattr(self.model, name), nn.ModuleList):
                module_lists.append(name)
                module_list_idx = i

        if len(module_lists) > 1:
            raise ValueError(
                "Pipeline parallel of models with multiple `ModuleList`s "
                "in the base model are not supported yet!")
        if module_list_idx is None:
            raise ValueError(
                f"Could not find `ModuleList` in {type(self.model)}")

        # Layers before module list
        for name in pp_plan[:module_list_idx]:
            if self.pp_group.is_first_rank or (
                    self.text_config.tie_word_embeddings
                    and self.pp_group.is_last_rank):
                continue
            setattr(self.model, name, PPMissingLayer())

        # Module list
        start_layer, end_layer = get_pp_indices(
            self.text_config.num_hidden_layers, self.pp_rank, self.pp_size)
        layers_name = pp_plan[module_list_idx]
        layers = getattr(self.model, layers_name)
        for i in range(len(layers)):
            if start_layer <= i < end_layer:
                continue
            layers[i] = PPMissingLayer()

        # Layers after module list
        for name in pp_plan[module_list_idx + 1:]:
            # Modules that should be on last rank
            if not self.pp_group.is_last_rank:
                setattr(self.model, name, PPMissingLayer())

    def recursive_replace(self):
        """Recursively replace modules in the model as needed.

        Currently, this replaces:

        - `nn.Linear` with vLLM's tensor parallel linear classes
        - `*RMSNorm` with vLLM's `RMSNorm` (currently disabled, see TODO)

        Raises:
            ValueError: If tensor parallelism is requested but the model
                has no TP plan.
        """
        tp_plan = self.model.tp_plan

        if not tp_plan and self.tp_size > 1:
            tip = get_feature_request_tip(self.model_config.model,
                                          self.model_config.trust_remote_code)
            raise ValueError(
                f"{type(self.model)} does not support tensor parallel. {tip}")

        # Prefix the patterns because we always start from `self.model`
        tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}

        def _recursive_replace(module: nn.Module, prefix: str):
            for child_name, child_module in module.named_children():
                new_module = child_module
                qual_name = maybe_prefix(prefix, child_name)
                if isinstance(child_module, nn.Linear):
                    # The first TP-plan regex matching the name wins.
                    generator = (p for p in tp_plan if re.match(p, qual_name))
                    pattern = next(generator, None)
                    # Some weight loaders expect all linear layers to inherit
                    # LinearBase, so we set a default style which causes any
                    # unspecified layers to be replaced with ReplicatedLinear
                    style = tp_plan.get(pattern, "replicate")
                    new_module = replace_linear_class(child_module,
                                                      style,
                                                      self.quant_config,
                                                      prefix=qual_name)
                # TODO(hmellor): Enable RMSNorm replacement once we have a way
                # to choose RMSNorm vs GemmaRMSNorm
                # elif child_module.__class__.__name__.endswith("RMSNorm"):
                #     new_module = replace_rms_norm_class(
                #         child_module, self.config.hidden_size)
                else:
                    _recursive_replace(child_module, prefix=qual_name)

                if new_module is not child_module:
                    setattr(module, child_name, new_module)
                    log_replacement(qual_name, child_module, new_module)

        _recursive_replace(self.model, prefix="model")

    def create_attention_instances(
        self,
        attn_type: AttentionType = AttentionType.DECODER
    ) -> dict[int, Attention]:
        """
        Create `Attention` instances to inform KV cache allocation.

        One instance is created per decoder layer owned by this pipeline
        parallel rank, keyed by layer index.

        Layers whose entry in the text config's `layer_types` is
        `"sliding_attention"` get the text config's `sliding_window`. Both
        are read from the text config, the same config that supplies
        `num_hidden_layers`. This matters for multimodal configs, where the
        top-level config may not carry these attributes.
        """
        num_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        start, end = get_pp_indices(self.text_config.num_hidden_layers,
                                    self.pp_rank, self.pp_size)
        layer_types = getattr(self.text_config, "layer_types", None)

        attention_instances = {}
        for i in range(start, end):
            # Handle interleaved sliding window attention
            per_layer_sliding_window = None
            if (layer_types is not None
                    and layer_types[i] == "sliding_attention"):
                per_layer_sliding_window = self.text_config.sliding_window

            attention_instances[i] = Attention(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=head_size**-0.5,
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                per_layer_sliding_window=per_layer_sliding_window,
                prefix=f"{i}.attn",
                attn_type=attn_type)
        return attention_instances

    def init_parameters(self,
                        module: nn.Module,
                        dtype: Optional[torch.dtype] = None):
        """
        Materialize every parameter still on the `meta` device.

        A parameter on the `meta` device belongs to an original module
        created by:

        ```python
        with torch.device("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```

        i.e. one that was not replaced by a vLLM layer. Each such parameter
        is replaced by an uninitialized tensor (`torch.empty_like`) of the
        same shape on `self.device_config.device`. The weights are expected
        to be filled in later, e.g. by `load_weights`.

        Args:
            module: Root module to walk recursively.
            dtype: dtype of the new parameters; defaults to
                `self.model_config.dtype`.
        """

        def _init_parameters(module: nn.Module, dtype: Optional[torch.dtype]):
            for name, param in module.named_parameters(recurse=False):
                if param.device == torch.device("meta"):
                    new_param = nn.Parameter(
                        torch.empty_like(
                            param.data,
                            dtype=dtype or self.model_config.dtype,
                            device=self.device_config.device,
                        ))
                    setattr(module, name, new_param)
            for child in module.children():
                _init_parameters(child, dtype)

        _init_parameters(module, dtype)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped Transformers model on a flat batch of tokens.

        A batch dimension of 1 is added before calling the model and removed
        from its first output. Non-first pipeline ranks take their input
        from `intermediate_tensors["hidden_states"]`. Non-last ranks return
        their hidden states wrapped in `IntermediateTensors`.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        if input_ids is not None:
            input_ids = input_ids[None, ...]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[None, ...]

        if self.model_config.uses_mrope:
            # M-RoPE positions carry a leading axis (presumably one row per
            # rope section), so the batch dim is inserted at index 1.
            position_ids = positions[:, None]
        else:
            position_ids = positions[None, ...]

        hidden_states = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            position_ids=position_ids,
            attention_instances=self.attention_instances,
            return_dict=False)[0][0, ...]  # we remove batch dimension for now

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return hidden_states

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load checkpoint weights, honoring the skip/ignore lists above.

        Returns:
            The set of loaded weight names, as returned by
            `AutoWeightsLoader.load_weights`.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=self.skip_prefixes,
            skip_substrs=self.skip_substrs,
            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
            ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def check_version(self, min_version: str, feature: str):
        """Raise `ImportError` if `transformers` is older than `min_version`."""
        installed = Version(transformers.__version__)
        required = Version(min_version)
        if installed < required:
            raise ImportError(
                f"Transformers backend requires transformers>={required} "
                f"for {feature}, but got {installed}")


@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):
    """Transformers backend causal LM: `TransformersBase` plus an LM head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Tell `TransformersBase.load_weights` to skip
        # `lm_head` if the model has tied word embeddings
        if self.text_config.tie_word_embeddings:
            self.skip_prefixes.append("lm_head.")

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = self.text_config.vocab_size
            self.lm_head = ParallelLMHead(
                self.text_config.vocab_size,
                self.text_config.hidden_size,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if self.text_config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.get_input_embeddings())

            logit_scale = getattr(self.text_config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size, self.text_config.vocab_size,
                logit_scale)
        else:
            self.lm_head = PPMissingLayer()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings()(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits


def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
    """Flatten until a list of tensors can be concatenated then do concat.

    Tensors are concatenated along dim 0 once they all share the same
    trailing shape (`shape[1:]`); until then `flatten_bn` is applied.
    """

    def _can_concat(x: list[torch.Tensor]):
        return len(set(map(lambda _x: _x.shape[1:], x))) == 1

    if _can_concat(x):
        return torch.concat(x)
    return flatten_and_concat(flatten_bn(x))


@MULTIMODAL_REGISTRY.register_processor(
    MultiModalProcessor,
    info=MultiModalProcessingInfo,
    dummy_inputs=MultiModalDummyInputsBuilder)
@support_torch_compile(
    # set `positions` to last dim to support Qwen-mrope
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    },
    enable_if=can_enable_torch_compile)
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
    """Transformers backend causal LM with image inputs."""
    merge_by_field_config = True
    # Backwards compatibility for prev released models. State dicts back then
    # had different formats and cannot be loaded with `AutoModel` mapping as is
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.model": "model.language_model",
            "text_model.model": "model.text_model",
            "vision_tower": "model.vision_tower",
            "vqmodel": "model.vqmodel",
            "visual": "model.visual",
            "vision_model": "model.vision_model",
            "vision_embed_tokens": "model.vision_embed_tokens",
            "image_newline": "model.image_newline",
            "multi_modal_projector": "model.multi_modal_projector",
            "text_model.lm_head": "lm_head",
            "language_model.lm_head": "lm_head",
            # Qwen models used "model" as the name for the language model.
            # Therefore, we must map each of submodule explicitly to avoid
            # conflicts with newer models that use "model.language_model".
            "model.embed_tokens": "model.language_model.embed_tokens",
            "model.layers": "model.language_model.layers",
            "model.norm": "model.language_model.norm",
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        self.dtype = vllm_config.model_config.dtype

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Same as `TransformersBase.forward`; extra `kwargs` are ignored."""
        return super().forward(input_ids, positions, intermediate_tensors,
                               inputs_embeds)

    def get_language_model(self) -> torch.nn.Module:
        return self.model

    def get_multimodal_embeddings(self, **kwargs):
        """Compute per-image embeddings from processed image inputs.

        Precomputed `image_embeds` are returned unchanged. Otherwise
        `pixel_values` (or `image_patches`) are passed, together with any
        remaining `kwargs`, to the model's `get_image_features`. If no image
        input is present, `None` is returned.

        Raises:
            KeyError: If pixel inputs are given without `num_image_patches`.
        """
        pixel_values: Optional[torch.Tensor] = kwargs.pop("pixel_values", None)
        image_embeds: Optional[torch.Tensor] = kwargs.pop("image_embeds", None)
        # Model might use `image_patches` instead of `pixel_values`
        if pixel_values is None:
            pixel_values = kwargs.pop("image_patches", None)

        if image_embeds is not None:
            return image_embeds

        if pixel_values is None:
            return None

        num_image_patches = kwargs.pop("num_image_patches")
        vision_embeddings = self.model.get_image_features(
            pixel_values, **kwargs)

        if isinstance(vision_embeddings, torch.Tensor):
            if isinstance(num_image_patches, list):
                num_image_patches = torch.cat(num_image_patches)

            if vision_embeddings.ndim == 2:
                vision_embeddings = vision_embeddings.unsqueeze(0)

            # Embeddings have to be 2D tensors of length `num_images`
            # but transformers returns concat tensors if each patch
            # is of different size. We split it back to make vLLM happy
            vision_embeddings = torch.split(
                vision_embeddings, num_image_patches.flatten().tolist())
            vision_embeddings = [
                embed.flatten(start_dim=0, end_dim=-2)
                for embed in vision_embeddings
            ]

        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """
        Apply token embeddings to `input_ids`.

        If `multimodal_embeddings` is passed, scatter them into
        `input_ids` according to the mask `is_multimodal`.

        `handle_oov_mm_token` is forwarded to the inherited
        `_get_text_embeddings`. Per its name, it exists for multi-modal token
        IDs that exceed the language model's vocabulary size; check that
        helper for the exact semantics of each value.

        Raises:
            ValueError: If non-empty `multimodal_embeddings` are passed
                without `is_multimodal`.
        """
        from .utils import _merge_multimodal_embeddings

        inputs_embeds = self._get_text_embeddings(
            input_ids,
            self.model.get_input_embeddings(),
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        if is_multimodal is None:
            raise ValueError(
                "`get_input_embeddings` now requires `is_multimodal` arg, "
                "please update your model runner according to "
                "https://github.com/vllm-project/vllm/pull/16229.")

        return _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )