# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio import json from abc import ABC, abstractmethod from collections import Counter, defaultdict, deque from collections.abc import Awaitable, Iterable from functools import cached_property, lru_cache, partial from pathlib import Path from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union, cast) import jinja2.nodes import transformers.utils.chat_template_utils as hf_chat_utils # yapf conflicts with isort for this block # yapf: disable from openai.types.chat import (ChatCompletionAssistantMessageParam, ChatCompletionContentPartImageParam, ChatCompletionContentPartInputAudioParam) from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) from openai.types.chat import (ChatCompletionContentPartRefusalParam, ChatCompletionContentPartTextParam) from openai.types.chat import ( ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) from openai.types.chat import (ChatCompletionMessageToolCallParam, ChatCompletionToolMessageParam) from openai.types.chat.chat_completion_content_part_input_audio_param import ( InputAudio) from openai.types.responses import ResponseInputImageParam from PIL import Image from pydantic import BaseModel, ConfigDict, TypeAdapter # yapf: enable from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin) # pydantic needs the TypedDict from typing_extensions from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.model_executor.models import SupportsMultiModal from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal.utils import MediaConnector # yapf: disable from vllm.transformers_utils.chat_templates import ( get_chat_template_fallback_path) # yapf: enable from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.utils import deprecate_kwargs, random_uuid logger = init_logger(__name__) MODALITY_PLACEHOLDERS_MAP = { "image": "<##IMAGE##>", "audio": "<##AUDIO##>", "video": "<##VIDEO##>", } class AudioURL(TypedDict, total=False): url: Required[str] """ Either a URL of the audio or a data URL with base64 encoded audio data. """ class ChatCompletionContentPartAudioParam(TypedDict, total=False): audio_url: Required[AudioURL] type: Required[Literal["audio_url"]] """The type of the content part.""" class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): image_embeds: Required[Union[str, dict[str, str]]] """ The image embeddings. It can be either: - A single base64 string. - A dictionary where each value is a base64 string. """ type: Required[Literal["image_embeds"]] """The type of the content part.""" class VideoURL(TypedDict, total=False): url: Required[str] """ Either a URL of the video or a data URL with base64 encoded video data. """ class ChatCompletionContentPartVideoParam(TypedDict, total=False): video_url: Required[VideoURL] type: Required[Literal["video_url"]] """The type of the content part.""" class PILImage(BaseModel): """ A PIL.Image.Image object. """ image_pil: Image.Image model_config = ConfigDict(arbitrary_types_allowed=True) class CustomChatCompletionContentPILImageParam(TypedDict, total=False): """A simpler version of the param that only accepts a PIL image. Example: { "image_pil": ImageAsset('cherry_blossom').pil_image } """ image_pil: Required[PILImage] class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): """A simpler version of the param that only accepts a plain image_url. This is supported by OpenAI API, although it is not documented. Example: { "image_url": "https://example.com/image.jpg" } """ image_url: Required[str] class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): """A simpler version of the param that only accepts a plain audio_url. Example: { "audio_url": "https://example.com/audio.mp3" } """ audio_url: Required[str] class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): """A simpler version of the param that only accepts a plain audio_url. Example: { "video_url": "https://example.com/video.mp4" } """ video_url: Required[str] ChatCompletionContentPartParam: TypeAlias = Union[ OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, ChatCompletionContentPartInputAudioParam, ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam, CustomChatCompletionContentPILImageParam, CustomChatCompletionContentSimpleImageParam, ChatCompletionContentPartImageEmbedsParam, CustomChatCompletionContentSimpleAudioParam, CustomChatCompletionContentSimpleVideoParam, str] class CustomChatCompletionMessageParam(TypedDict, total=False): """Enables custom roles in the Chat Completion API.""" role: Required[str] """The role of the message's author.""" content: Union[str, list[ChatCompletionContentPartParam]] """The contents of the message.""" name: str """An optional name for the participant. Provides the model information to differentiate between participants of the same role. """ tool_call_id: Optional[str] """Tool call that this message is responding to.""" tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] """The tool calls generated by the model, such as function calls.""" ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, CustomChatCompletionMessageParam] # TODO: Make fields ReadOnly once mypy supports it class ConversationMessage(TypedDict, total=False): role: Required[str] """The role of the message's author.""" content: Union[Optional[str], list[dict[str, str]]] """The contents of the message""" tool_call_id: Optional[str] """Tool call that this message is responding to.""" name: Optional[str] """The name of the function to call""" tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] """The tool calls generated by the model, such as function calls.""" # Passed in by user ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] # Used internally _ChatTemplateContentFormat = Literal["string", "openai"] def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: if isinstance(node, jinja2.nodes.Name): return node.ctx == "load" and node.name == varname return False def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: if isinstance(node, jinja2.nodes.Getitem): return (_is_var_access(node.node, varname) and isinstance(node.arg, jinja2.nodes.Const) and node.arg.value == key) if isinstance(node, jinja2.nodes.Getattr): return _is_var_access(node.node, varname) and node.attr == key return False def _is_var_or_elems_access( node: jinja2.nodes.Node, varname: str, key: Optional[str] = None, ) -> bool: if isinstance(node, jinja2.nodes.Filter): return (node.node is not None and _is_var_or_elems_access(node.node, varname, key)) if isinstance(node, jinja2.nodes.Test): return _is_var_or_elems_access(node.node, varname, key) if (isinstance(node, jinja2.nodes.Getitem) and isinstance(node.arg, jinja2.nodes.Slice)): return _is_var_or_elems_access(node.node, varname, key) # yapf: disable return ( _is_attr_access(node, varname, key) if key else _is_var_access(node, varname) ) # yapf: enable def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): # Global variable that is implicitly defined at the root yield root, varname # Iterative BFS related_varnames = deque([varname]) while related_varnames: related_varname = related_varnames.popleft() for assign_ast in root.find_all(jinja2.nodes.Assign): lhs = assign_ast.target rhs = assign_ast.node if _is_var_or_elems_access(rhs, related_varname): assert isinstance(lhs, jinja2.nodes.Name) yield assign_ast, lhs.name # Avoid infinite looping for self-assignment if lhs.name != related_varname: related_varnames.append(lhs.name) # NOTE: The proper way to handle this is to build a CFG so that we can handle # the scope in which each variable is defined, but that is too complicated def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): messages_varnames = [ varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") ] # Search for {%- for message in messages -%} loops for loop_ast in root.find_all(jinja2.nodes.For): loop_iter = loop_ast.iter loop_target = loop_ast.target for varname in messages_varnames: if _is_var_or_elems_access(loop_iter, varname): assert isinstance(loop_target, jinja2.nodes.Name) yield loop_ast, loop_target.name break def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): message_varnames = [ varname for _, varname in _iter_nodes_assign_messages_item(root) ] # Search for {%- for content in message['content'] -%} loops for loop_ast in root.find_all(jinja2.nodes.For): loop_iter = loop_ast.iter loop_target = loop_ast.target for varname in message_varnames: if _is_var_or_elems_access(loop_iter, varname, "content"): assert isinstance(loop_target, jinja2.nodes.Name) yield loop_ast, loop_target.name break def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]: try: jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) return jinja_compiled.environment.parse(chat_template) except Exception: logger.exception("Error when compiling Jinja template") return None @lru_cache(maxsize=32) def _detect_content_format( chat_template: str, *, default: _ChatTemplateContentFormat, ) -> _ChatTemplateContentFormat: jinja_ast = _try_extract_ast(chat_template) if jinja_ast is None: return default try: next(_iter_nodes_assign_content_item(jinja_ast)) except StopIteration: return "string" except Exception: logger.exception("Error when parsing AST of Jinja template") return default else: return "openai" def resolve_mistral_chat_template( chat_template: Optional[str], **kwargs: Any, ) -> Optional[str]: if chat_template is not None: logger.warning_once( "'chat_template' cannot be overridden for mistral tokenizer.") if "add_generation_prompt" in kwargs: logger.warning_once( "'add_generation_prompt' is not supported for mistral tokenizer, " "so it will be ignored.") if "continue_final_message" in kwargs: logger.warning_once( "'continue_final_message' is not supported for mistral tokenizer, " "so it will be ignored.") return None @deprecate_kwargs( "trust_remote_code", additional_message="Please use `model_config.trust_remote_code` instead.", ) def resolve_hf_chat_template( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], *, model_config: ModelConfig, trust_remote_code: Optional[bool] = None, ) -> Optional[str]: # 1st priority: The given chat template if chat_template is not None: return chat_template # 2nd priority: AutoProcessor chat template, unless tool calling is enabled if tools is None: try: processor = cached_get_processor( tokenizer.name_or_path, processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin), trust_remote_code=model_config.trust_remote_code, ) if isinstance(processor, ProcessorMixin) and \ hasattr(processor, 'chat_template') and \ processor.chat_template is not None: return processor.chat_template except Exception: logger.debug("Failed to load AutoProcessor chat template for %s", tokenizer.name_or_path, exc_info=True) # noqa: E501 # 3rd priority: AutoTokenizer chat template try: return tokenizer.get_chat_template(chat_template, tools=tools) except Exception: logger.debug("Failed to load AutoTokenizer chat template for %s", tokenizer.name_or_path, exc_info=True) # 4th priority: Predefined fallbacks path = get_chat_template_fallback_path( model_type=model_config.hf_config.model_type, tokenizer_name_or_path=model_config.tokenizer, ) if path is not None: logger.info("Loading chat template fallback for %s as there isn't one " "defined on HF Hub.", tokenizer.name_or_path) chat_template = load_chat_template(path) else: logger.debug("There is no chat template fallback for %s", tokenizer.name_or_path) return chat_template def _resolve_chat_template_content_format( chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], tokenizer: AnyTokenizer, *, model_config: ModelConfig, ) -> _ChatTemplateContentFormat: if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): hf_chat_template = resolve_hf_chat_template( tokenizer, chat_template=chat_template, tools=tools, model_config=model_config, ) else: hf_chat_template = None jinja_text = (hf_chat_template if isinstance(hf_chat_template, str) else load_chat_template(chat_template, is_literal=True)) detected_format = ("string" if jinja_text is None else _detect_content_format(jinja_text, default="string")) return detected_format @lru_cache def _log_chat_template_content_format( chat_template: Optional[str], given_format: ChatTemplateContentFormatOption, detected_format: ChatTemplateContentFormatOption, ): logger.info( "Detected the chat template content format to be '%s'. " "You can set `--chat-template-content-format` to override this.", detected_format, ) if given_format != "auto" and given_format != detected_format: logger.warning( "You specified `--chat-template-content-format %s` " "which is different from the detected format '%s'. " "If our automatic detection is incorrect, please consider " "opening a GitHub issue so that we can improve it: " "https://github.com/vllm-project/vllm/issues/new/choose", given_format, detected_format, ) @deprecate_kwargs( "trust_remote_code", additional_message="Please use `model_config.trust_remote_code` instead.", ) def resolve_chat_template_content_format( chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], given_format: ChatTemplateContentFormatOption, tokenizer: AnyTokenizer, *, model_config: ModelConfig, trust_remote_code: Optional[bool] = None, ) -> _ChatTemplateContentFormat: if given_format != "auto": return given_format detected_format = _resolve_chat_template_content_format( chat_template, tools, tokenizer, model_config=model_config, ) _log_chat_template_content_format( chat_template, given_format=given_format, detected_format=detected_format, ) return detected_format ModalityStr = Literal["image", "audio", "video", "image_embeds"] _T = TypeVar("_T") class BaseMultiModalItemTracker(ABC, Generic[_T]): """ Tracks multi-modal items in a given request and ensures that the number of multi-modal items in a given request does not exceed the configured maximum per prompt. """ def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): super().__init__() self._model_config = model_config self._tokenizer = tokenizer self._items_by_modality = defaultdict[str, list[_T]](list) @property def model_config(self) -> ModelConfig: return self._model_config @cached_property def model_cls(self): from vllm.model_executor.model_loader import get_model_cls return get_model_cls(self.model_config) @property def allowed_local_media_path(self): return self._model_config.allowed_local_media_path @property def mm_registry(self): return MULTIMODAL_REGISTRY def add(self, modality: ModalityStr, item: _T) -> Optional[str]: """ Add a multi-modal item to the current prompt and returns the placeholder string to use, if any. """ mm_registry = self.mm_registry model_config = self.model_config model_cls = cast(SupportsMultiModal, self.model_cls) input_modality = modality.replace("_embeds", "") if mm_registry.has_processor(model_config): mm_processor = mm_registry.create_processor(model_config) allowed_counts = mm_processor.info.get_allowed_mm_limits() allowed_count = allowed_counts.get(input_modality, 0) else: mm_config = model_config.multimodal_config if mm_config is None: msg = "This model does not support multi-modal inputs" raise ValueError(msg) allowed_count = mm_config.get_limit_per_prompt(input_modality) current_count = len(self._items_by_modality[modality]) + 1 if current_count > allowed_count: raise ValueError( f"At most {allowed_count} {modality}(s) may be provided in " "one request. You can set `--limit-mm-per-prompt` to " "increase this limit if the model supports it.") self._items_by_modality[modality].append(item) return model_cls.get_placeholder_str(modality, current_count) @abstractmethod def create_parser(self) -> "BaseMultiModalContentParser": raise NotImplementedError class MultiModalItemTracker(BaseMultiModalItemTracker[object]): def all_mm_data(self) -> Optional[MultiModalDataDict]: if not self._items_by_modality: return None mm_inputs = {} items_by_modality = dict(self._items_by_modality) if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError(\ "Mixing raw image and embedding inputs is not allowed") if "image_embeds" in items_by_modality: image_embeds_lst = items_by_modality["image_embeds"] if len(image_embeds_lst) > 1: raise ValueError(\ "Only one message can have {'type': 'image_embeds'}") mm_inputs["image"] = image_embeds_lst[0] if "image" in items_by_modality: mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio" in items_by_modality: mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: mm_inputs["video"] = items_by_modality["video"] # A list of videos return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": return MultiModalContentParser(self) class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): async def all_mm_data(self) -> Optional[MultiModalDataDict]: if not self._items_by_modality: return None mm_inputs = {} items_by_modality = { modality: await asyncio.gather(*items) for modality, items in self._items_by_modality.items() } if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError( "Mixing raw image and embedding inputs is not allowed") if "image_embeds" in items_by_modality: image_embeds_lst = items_by_modality["image_embeds"] if len(image_embeds_lst) > 1: raise ValueError( "Only one message can have {'type': 'image_embeds'}") mm_inputs["image"] = image_embeds_lst[0] if "image" in items_by_modality: mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio" in items_by_modality: mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: mm_inputs["video"] = items_by_modality["video"] # A list of videos return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": return AsyncMultiModalContentParser(self) class BaseMultiModalContentParser(ABC): def __init__(self) -> None: super().__init__() # stores model placehodlers list with corresponding # general MM placeholder: # { # "<##IMAGE##>": ["", "", ""], # "<##AUDIO##>": ["