"vscode:/vscode.git/clone" did not exist on "b706d898af7c55dc854858bace3c9041cf22da66"
chat_utils.py 56.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import inspect
6
import json
7
from abc import ABC, abstractmethod
8
from collections import Counter, defaultdict, deque
9
from collections.abc import Awaitable, Callable, Iterable
10
from functools import cached_property, lru_cache, partial
11
from pathlib import Path
12
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
13

14
15
16
import jinja2
import jinja2.ext
import jinja2.meta
17
import jinja2.nodes
18
19
import jinja2.parser
import jinja2.sandbox
20
import transformers.utils.chat_template_utils as hf_chat_utils
21
from openai.types.chat import (
22
23
24
25
26
27
28
29
30
31
32
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionContentPartTextParam,
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
33
from openai.types.chat import (
34
35
36
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
37
from openai.types.responses import ResponseInputImageParam
38
from openai_harmony import Message as OpenAIHarmonyMessage
39
40
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
41
42
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin

43
# pydantic needs the TypedDict from typing_extensions
44
from typing_extensions import Required, TypedDict
45

46
from vllm import envs
47
from vllm.config import ModelConfig
48
from vllm.logger import init_logger
49
from vllm.model_executor.models import SupportsMultiModal
50
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
51
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
52
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
53
from vllm.transformers_utils.processor import cached_get_processor
54
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
55
from vllm.utils import random_uuid
56
from vllm.utils.func_utils import supports_kw
57
58
59

logger = init_logger(__name__)

60
61
62
63
64
65
# Maps a modality name to the generic multi-modal placeholder token that
# `BaseMultiModalContentParser` uses as the key of its placeholder storage.
MODALITY_PLACEHOLDERS_MAP = {
    "image": "<##IMAGE##>",
    "audio": "<##AUDIO##>",
    "video": "<##VIDEO##>",
}

66

67
68
69
70
71
72
73
74
75
76
77
78
79
80
class AudioURL(TypedDict, total=False):
    """Reference to an audio input; `url` is the only (required) key."""

    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Content part of type `"audio_url"` that wraps an `AudioURL`."""

    # Both keys are required even though the TypedDict is `total=False`.
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


81
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
82
    image_embeds: str | dict[str, str] | None
83
84
85
86
87
88
89
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
90
    uuid: str | None
91
92
93
94
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
95
96


97
98
99
100
101
102
103
104
105
106
107
108
109
110
class VideoURL(TypedDict, total=False):
    """Reference to a video input; `url` is the only (required) key."""

    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Content part of type `"video_url"` that wraps a `VideoURL`."""

    # Both keys are required even though the TypedDict is `total=False`.
    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""


111
112
113
114
class PILImage(BaseModel):
    """
    A PIL.Image.Image object.

    Pydantic wrapper model; `arbitrary_types_allowed` is enabled so that the
    non-pydantic `Image.Image` type can be used as a field type.
    """

    image_pil: Image.Image
    model_config = ConfigDict(arbitrary_types_allowed=True)


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    image_pil: PILImage | None
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
135
136


137
138
139
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.
140

141
142
143
144
145
    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """
146

147
148
    image_url: str | None
    uuid: str | None
149
150
151
152
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
153
154
155
156


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.
157

158
159
160
161
162
    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """
163

164
    audio_url: str | None
165
166


167
168
169
170
171
172
173
174
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """
175

176
177
    video_url: str | None
    uuid: str | None
178
179
180
181
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
182
183


Julien Denize's avatar
Julien Denize committed
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    `thinking` and `type` are required keys; `closed` is optional.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


205
206
207
208
209
210
211
212
213
214
215
216
217
218
# Union of every content-part shape usable inside a message's `content` list:
# the OpenAI-defined parts, the extension TypedDicts declared in this module,
# and a bare `str`.
ChatCompletionContentPartParam: TypeAlias = (
    OpenAIChatCompletionContentPartParam
    | ChatCompletionContentPartAudioParam
    | ChatCompletionContentPartInputAudioParam
    | ChatCompletionContentPartVideoParam
    | ChatCompletionContentPartRefusalParam
    | CustomChatCompletionContentPILImageParam
    | CustomChatCompletionContentSimpleImageParam
    | ChatCompletionContentPartImageEmbedsParam
    | CustomChatCompletionContentSimpleAudioParam
    | CustomChatCompletionContentSimpleVideoParam
    | str
    | CustomThinkCompletionContentParam
)
219
220
221
222


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API.

    Only `role` is required; every other key is optional.
    """

    role: Required[str]
    """The role of the message's author."""

    content: str | list[ChatCompletionContentPartParam]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: str | None
    """Tool call that this message is responding to."""

    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
    """The tool calls generated by the model, such as function calls."""

243

244
245
246
247
248
# Union of the message shapes covered by this module's type alias: an OpenAI
# message param, a custom-role message param, or an OpenAI Harmony message.
ChatCompletionMessageParam: TypeAlias = (
    OpenAIChatCompletionMessageParam
    | CustomChatCompletionMessageParam
    | OpenAIHarmonyMessage
)
249
250


251
# TODO: Make fields ReadOnly once mypy supports it
252
253
254
255
class ConversationMessage(TypedDict, total=False):
    """A single conversation message; only `role` is a required key."""

    role: Required[str]
    """The role of the message's author."""

    content: str | None | list[dict[str, str]]
    """The contents of the message"""

    tool_call_id: str | None
    """Tool call that this message is responding to."""

    name: str | None
    """The name of the function to call"""

    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
    """The tool calls generated by the model, such as function calls."""
267
268


269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
# Passed in by user; "auto" asks for the format to be detected from the
# chat template instead of being taken as given.
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

# Used internally, i.e. the concrete formats left once "auto" is resolved
_ChatTemplateContentFormat = Literal["string", "openai"]


def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Return True if `node` is a `Name` node that loads variable `varname`."""
    return (
        isinstance(node, jinja2.nodes.Name)
        and node.ctx == "load"
        and node.name == varname
    )


def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Return True if `node` reads `key` from variable `varname`.

    Two spellings are recognised: subscript access with a constant key
    (a `Getitem` node, e.g. `varname['key']`) and attribute access
    (a `Getattr` node, e.g. `varname.key`).
    """
    if isinstance(node, jinja2.nodes.Getitem):
        return (
            _is_var_access(node.node, varname)
            and isinstance(node.arg, jinja2.nodes.Const)
            and node.arg.value == key
        )

    if isinstance(node, jinja2.nodes.Getattr):
        return _is_var_access(node.node, varname) and node.attr == key

    return False


def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: str | None = None,
) -> bool:
    """Return True if `node` accesses `varname` (or `varname[key]`), possibly
    wrapped in filters, tests or slices.

    Args:
        node: The Jinja AST node to inspect.
        varname: Name of the variable being looked for.
        key: When given (and non-empty), look for an access to this
            key/attribute of `varname` instead of `varname` itself.
    """
    # Unwrap filters, tests and slices, then re-check the wrapped expression.
    if isinstance(node, jinja2.nodes.Filter):
        return node.node is not None and _is_var_or_elems_access(
            node.node, varname, key
        )
    if isinstance(node, jinja2.nodes.Test):
        return _is_var_or_elems_access(node.node, varname, key)

    if isinstance(node, jinja2.nodes.Getitem) and isinstance(
        node.arg, jinja2.nodes.Slice
    ):
        return _is_var_or_elems_access(node.node, varname, key)

    return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342


def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
    """Yield `(node, name)` for `varname` and every variable assigned from it.

    The first pair is `(root, varname)` itself. After that, every `Assign`
    node under `root` whose right-hand side accesses an already-known name
    yields `(assign_node, target_name)`, and the target is followed in turn.

    Each name is expanded at most once, so assignment cycles such as
    `{% set a = messages %}{% set messages = a %}` terminate instead of
    re-enqueueing each other forever. As a consequence, an assignment that
    is reachable through several alias paths is not yielded repeatedly.
    """
    # Global variable that is implicitly defined at the root
    yield root, varname

    # Iterative BFS; `seen_varnames` holds every name that was ever enqueued
    seen_varnames = {varname}
    related_varnames = deque([varname])
    while related_varnames:
        related_varname = related_varnames.popleft()

        for assign_ast in root.find_all(jinja2.nodes.Assign):
            lhs = assign_ast.target
            rhs = assign_ast.node

            if _is_var_or_elems_access(rhs, related_varname):
                assert isinstance(lhs, jinja2.nodes.Name)
                yield assign_ast, lhs.name

                # Avoid infinite looping for self-assignment and for longer
                # assignment cycles: never enqueue a name twice
                if lhs.name not in seen_varnames:
                    seen_varnames.add(lhs.name)
                    related_varnames.append(lhs.name)


# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
    """Yield `(for_node, target_name)` for each loop over the messages.

    A `For` node qualifies when its iterable accesses `messages` or any
    variable found to be assigned from it; the loop target must be a plain
    name.
    """
    messages_varnames = [
        varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
    ]

    # Search for {%- for message in messages -%} loops
    for loop_ast in root.find_all(jinja2.nodes.For):
        loop_iter = loop_ast.iter
        loop_target = loop_ast.target

        for varname in messages_varnames:
            if _is_var_or_elems_access(loop_iter, varname):
                assert isinstance(loop_target, jinja2.nodes.Name)
                yield loop_ast, loop_target.name
                # One match is enough; do not yield the same loop twice
                break


def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    """Yield `(for_node, target_name)` for each loop over a message's content.

    A `For` node qualifies when its iterable accesses the `"content"` key of
    any loop variable produced by `_iter_nodes_assign_messages_item`.
    """
    candidate_names = [name for _, name in _iter_nodes_assign_messages_item(root)]

    # Search for {%- for content in message['content'] -%} loops
    for for_node in root.find_all(jinja2.nodes.For):
        iterable = for_node.iter
        target = for_node.target

        iterates_content = any(
            _is_var_or_elems_access(iterable, name, "content")
            for name in candidate_names
        )
        if not iterates_content:
            continue

        assert isinstance(target, jinja2.nodes.Name)
        yield for_node, target.name


375
def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None:
    """Parse `chat_template` into a Jinja AST, or return None on failure.

    The template is first compiled through the HF chat-template helper and
    then parsed with the environment of that compiled template. Any
    exception is logged and swallowed.
    """
    try:
        jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
        return jinja_compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None


384
@lru_cache(maxsize=32)
def _detect_content_format(
    chat_template: str,
    *,
    default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
    """Detect which content format `chat_template` expects.

    Returns `"openai"` when the template contains at least one loop over a
    message's `content`, `"string"` when it contains none, and `default`
    when the template cannot be parsed or inspecting its AST raises.
    Results are memoised per `(chat_template, default)` pair.
    """
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    try:
        # Only the existence of a first match matters
        next(_iter_nodes_assign_content_item(jinja_ast))
    except StopIteration:
        return "string"
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default
    else:
        return "openai"


405
def resolve_mistral_chat_template(
406
    chat_template: str | None,
407
    **kwargs: Any,
408
) -> str | None:
409
410
411
412
    if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
        raise ValueError(
            "'chat_template' or 'chat_template_kwargs' cannot be overridden "
            "for mistral tokenizer."
413
        )
414

415
416
    return None

417

418
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]()
"""
Used in `_try_get_processor_chat_template` to avoid calling
`cached_get_processor` again if the processor fails to be loaded.

This is needed because `lru_cache` does not cache when an exception happens.
"""


def _try_get_processor_chat_template(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    model_config: ModelConfig,
) -> str | None:
    """Return the processor's chat template for `tokenizer`, if it has one.

    The outcome (including None for "no template" or "loading failed") is
    stored in `_PROCESSOR_CHAT_TEMPLATES`, keyed by the tokenizer's
    `name_or_path` and `model_config.trust_remote_code`, so the processor is
    loaded at most once per key.
    """
    cache_key = (tokenizer.name_or_path, model_config.trust_remote_code)
    if cache_key in _PROCESSOR_CHAT_TEMPLATES:
        return _PROCESSOR_CHAT_TEMPLATES[cache_key]

    try:
        processor = cached_get_processor(
            tokenizer.name_or_path,
            processor_cls=(
                PreTrainedTokenizer,
                PreTrainedTokenizerFast,
                ProcessorMixin,
            ),
            trust_remote_code=model_config.trust_remote_code,
        )
        # Only a `ProcessorMixin` result is consulted for a chat template
        if (
            isinstance(processor, ProcessorMixin)
            and hasattr(processor, "chat_template")
            and (chat_template := processor.chat_template) is not None
        ):
            _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template
            return chat_template
    except Exception:
        # Best effort: a failure is logged at debug level and cached as None
        logger.debug(
            "Failed to load AutoProcessor chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    _PROCESSOR_CHAT_TEMPLATES[cache_key] = None
    return None


463
def resolve_hf_chat_template(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    *,
    model_config: ModelConfig,
) -> str | None:
    """Resolve the chat template to use for an HF tokenizer.

    Sources are tried in priority order: the explicitly given template, the
    processor's template (only when `tools` is None), the tokenizer's own
    template, and finally a predefined fallback file.

    Returns:
        The resolved chat template, or None if no source provides one.
    """
    # 1st priority: The given chat template
    if chat_template is not None:
        return chat_template

    # 2nd priority: AutoProcessor chat template, unless tool calling is enabled
    if tools is None:
        chat_template = _try_get_processor_chat_template(tokenizer, model_config)
        if chat_template is not None:
            return chat_template

    # 3rd priority: AutoTokenizer chat template
    try:
        return tokenizer.get_chat_template(chat_template, tools=tools)
    except Exception:
        logger.debug(
            "Failed to load AutoTokenizer chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    # 4th priority: Predefined fallbacks
    path = get_chat_template_fallback_path(
        model_type=model_config.hf_config.model_type,
        tokenizer_name_or_path=model_config.tokenizer,
    )
    if path is not None:
        logger.info_once(
            "Loading chat template fallback for %s as there isn't one "
            "defined on HF Hub.",
            tokenizer.name_or_path,
        )
        chat_template = load_chat_template(path)
    else:
        logger.debug_once(
            "There is no chat template fallback for %s", tokenizer.name_or_path
        )

    return chat_template
508
509


510
def _resolve_chat_template_content_format(
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """Detect the content format implied by the effective chat template.

    For HF tokenizers the template is resolved via
    `resolve_hf_chat_template`; otherwise (or when that yields no string)
    `chat_template` is loaded as a literal. When no template text is
    available the format is `"string"`.
    """
    if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        hf_chat_template = resolve_hf_chat_template(
            tokenizer,
            chat_template=chat_template,
            tools=tools,
            model_config=model_config,
        )
    else:
        hf_chat_template = None

    jinja_text = (
        hf_chat_template
        if isinstance(hf_chat_template, str)
        else load_chat_template(chat_template, is_literal=True)
    )

    detected_format = (
        "string"
        if jinja_text is None
        else _detect_content_format(jinja_text, default="string")
    )

    return detected_format
540
541
542


@lru_cache
def _log_chat_template_content_format(
    chat_template: str | None,
    given_format: ChatTemplateContentFormatOption,
    detected_format: ChatTemplateContentFormatOption,
):
    """Log the detected content format, warning on a mismatch.

    Wrapped in `lru_cache`, so a repeated call with the same argument
    values does not log again while its entry is still cached.
    `chat_template` is not used in the body; it only takes part in the
    cache key.
    """
    logger.info(
        "Detected the chat template content format to be '%s'. "
        "You can set `--chat-template-content-format` to override this.",
        detected_format,
    )

    if given_format != "auto" and given_format != detected_format:
        logger.warning(
            "You specified `--chat-template-content-format %s` "
            "which is different from the detected format '%s'. "
            "If our automatic detection is incorrect, please consider "
            "opening a GitHub issue so that we can improve it: "
            "https://github.com/vllm-project/vllm/issues/new/choose",
            given_format,
            detected_format,
        )

565
566

def resolve_chat_template_content_format(
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    given_format: ChatTemplateContentFormatOption,
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """Resolve the user-facing content-format option to a concrete format.

    Any value other than `"auto"` is returned unchanged. For `"auto"` the
    format is detected from the effective chat template and the result is
    logged.
    """
    if given_format != "auto":
        return given_format

    detected_format = _resolve_chat_template_content_format(
        chat_template,
        tools,
        tokenizer,
        model_config=model_config,
    )

    _log_chat_template_content_format(
        chat_template,
        given_format=given_format,
        detected_format=detected_format,
    )

    return detected_format
591

592

593
# Modality names accepted by the multi-modal item trackers and parsers.
ModalityStr = Literal["image", "audio", "video", "image_embeds"]

# Type of the items stored by a `BaseMultiModalItemTracker`.
_T = TypeVar("_T")


class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """

    def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
        super().__init__()

        self._model_config = model_config
        self._tokenizer = tokenizer

        # Items and their UUIDs are appended in lockstep by `add`, so the
        # two lists of a modality always have the same length.
        self._items_by_modality = defaultdict[str, list[_T | None]](list)
        self._uuids_by_modality = defaultdict[str, list[str | None]](list)

    @property
    def model_config(self) -> ModelConfig:
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        # Imported at call time rather than at module import time
        from vllm.model_executor.model_loader import get_model_cls

        model_cls = get_model_cls(self.model_config)
        return cast(type[SupportsMultiModal], model_cls)

    @property
    def allowed_local_media_path(self):
        return self._model_config.allowed_local_media_path

    @property
    def allowed_media_domains(self):
        return self._model_config.allowed_media_domains

    @property
    def mm_registry(self):
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        return self.mm_registry.create_processor(self.model_config)

    def add(
        self,
        modality: ModalityStr,
        item: _T | None,
        uuid: str | None = None,
    ) -> str | None:
        """
        Add a multi-modal item to the current prompt and returns the
        placeholder string to use, if any.

        An optional uuid can be added which serves as a unique identifier of the
        media.
        """
        # Item-count validation is done against the modality name without
        # its "_embeds" suffix
        input_modality = modality.replace("_embeds", "")
        num_items = len(self._items_by_modality[modality]) + 1

        self.mm_processor.validate_num_items(input_modality, num_items)

        self._items_by_modality[modality].append(item)
        self._uuids_by_modality[modality].append(uuid)

        return self.model_cls.get_placeholder_str(modality, num_items)

    def all_mm_uuids(self) -> MultiModalUUIDDict | None:
        """Return the tracked UUIDs grouped by modality, or None if nothing
        was tracked.

        UUIDs recorded under `"image_embeds"` are reported under `"image"`.

        Raises:
            ValueError: If both raw images and image embeddings were added,
                or if more than one image-embeds item was added.
        """
        if not self._items_by_modality:
            return None
        mm_uuids = {}
        uuids_by_modality = dict(self._uuids_by_modality)
        if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
            raise ValueError("Mixing raw image and embedding inputs is not allowed")

        if "image_embeds" in uuids_by_modality:
            image_embeds_uuids = uuids_by_modality["image_embeds"]
            if len(image_embeds_uuids) > 1:
                raise ValueError("Only one message can have {'type': 'image_embeds'}")
            mm_uuids["image"] = uuids_by_modality["image_embeds"]
        if "image" in uuids_by_modality:
            mm_uuids["image"] = uuids_by_modality["image"]  # UUIDs of images
        if "audio" in uuids_by_modality:
            mm_uuids["audio"] = uuids_by_modality["audio"]  # UUIDs of audios
        if "video" in uuids_by_modality:
            mm_uuids["video"] = uuids_by_modality["video"]  # UUIDs of videos
        return mm_uuids

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        raise NotImplementedError


689
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
    """Tracker whose items are already-loaded objects."""

    def all_mm_data(self) -> MultiModalDataDict | None:
        """Return the tracked items grouped by modality, or None if nothing
        was tracked.

        A single image-embeds item is reported under `"image"` as the item
        itself, whereas raw images, audios and videos are reported as lists.

        Raises:
            ValueError: If both raw images and image embeddings were added,
                or if more than one image-embeds item was added.
        """
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        items_by_modality = dict(self._items_by_modality)
        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError("Mixing raw image and embedding inputs is not allowed")

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError("Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)


715
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
    """Tracker whose items are awaitables that resolve to the media objects."""

    async def all_mm_data(self) -> MultiModalDataDict | None:
        """Await all tracked items and return them grouped by modality, or
        None if nothing was tracked.

        A single image-embeds item is reported under `"image"` as the item
        itself, whereas raw images, audios and videos are reported as lists.

        Raises:
            ValueError: If both raw images and image embeddings were added,
                or if more than one image-embeds item was added.
        """
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        items_by_modality = {}
        for modality, items in self._items_by_modality.items():
            coros = []
            for item in items:
                if item is not None:
                    coros.append(item)
                else:
                    # Stand-in awaitable resolving to None, so a missing
                    # item keeps its position in the gathered results
                    coros.append(asyncio.sleep(0))
            items_by_modality[modality] = await asyncio.gather(*coros)

        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError("Mixing raw image and embedding inputs is not allowed")

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError("Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    """Interface for parsing multi-modal content parts of a message.

    Subclasses implement the `parse_*` methods; this base class only keeps
    the placeholder strings collected while parsing.
    """

    def __init__(self) -> None:
        super().__init__()

        # stores model placeholders list with corresponding
        # general MM placeholder:
        # {
        #   "<##IMAGE##>": ["<image>", "<image>", "<image>"],
        #   "<##AUDIO##>": ["<audio>", "<audio>"]
        # }
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(self, modality: ModalityStr, placeholder: str | None):
        """Record `placeholder` under the generic token of `modality`.

        Falsy placeholders (None or an empty string) are not stored.
        """
        mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
        if placeholder:
            self._placeholder_storage[mod_placeholder].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return a shallow copy of the collected placeholders as a dict."""
        return dict(self._placeholder_storage)

    @abstractmethod
    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        raise NotImplementedError

802
803
804
805
806
807

class MultiModalContentParser(BaseMultiModalContentParser):
    """Content parser that fetches media through the connector's
    non-async `fetch_*` methods and registers the results with a
    `MultiModalItemTracker`.
    """

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        multimodal_config = self._tracker.model_config.multimodal_config
        # `getattr` with a default tolerates a config without this attribute
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        return self._tracker.model_config

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        # A falsy URL is tracked as a None item (no fetch is attempted)
        image = self._connector.fetch_image(image_url) if image_url else None

        placeholder = self._tracker.add("image", image, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Track image embeddings given as a string, a dict of strings or
        None.

        Raises:
            ValueError: If multi-modal embedding inputs are not enabled in
                the multi-modal config.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            placeholder = self._tracker.add("image_embeds", embeds, uuid)

        if isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            placeholder = self._tracker.add("image_embeds", embedding, uuid)

        if image_embeds is None:
            placeholder = self._tracker.add("image_embeds", None, uuid)

        # Embeddings share the image placeholder bucket
        self._add_placeholder("image", placeholder)

    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        placeholder = self._tracker.add("image", image_pil, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        # A falsy URL is tracked as a None item (no fetch is attempted)
        audio = self._connector.fetch_audio(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", audio, uuid)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Convert inline audio to a base64 data URL and delegate to
        `parse_audio`."""
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"
            else:
                # If a UUID is provided, audio data may be empty.
                audio_url = None
        else:
            audio_url = None

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        # A falsy URL is tracked as a None item (no fetch is attempted)
        video = self._connector.fetch_video(video_url=video_url) if video_url else None

        placeholder = self._tracker.add("video", video, uuid)
        self._add_placeholder("video", placeholder)
888

889
890
891
892
893
894

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Content parser that hands multimodal items to an async tracker.

    URL-based media are registered as whatever the connector's ``*_async``
    fetchers return, without being awaited here. Items that are already
    available (PIL images, image embeddings) are wrapped in futures whose
    result is set immediately.
    """

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        multimodal_config = self._tracker.model_config.multimodal_config
        # ``multimodal_config`` may be None, hence the defensive getattr.
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        """Model configuration of the underlying tracker."""
        return self._tracker.model_config

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Register the (not yet awaited) fetch of ``image_url`` as an image."""
        image_coro = self._connector.fetch_image_async(image_url) if image_url else None

        placeholder = self._tracker.add("image", image_coro, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Register precomputed image embeddings.

        Args:
            image_embeds: A single encoded embedding, a mapping of names to
                encoded embeddings, or ``None``.
            uuid: Optional identifier passed through to the tracker.

        Raises:
            ValueError: If multimodal embedding inputs are not enabled.
            TypeError: If ``image_embeds`` is neither ``str``, ``dict`` nor
                ``None``.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future()

        # NOTE: the embeddings are decoded synchronously; the future only
        # adapts the result to the awaitable interface the tracker expects.
        if image_embeds is None:
            future.set_result(None)
        elif isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            future.set_result(embeds)
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            future.set_result(embedding)
        else:
            # Previously an unsupported type left the future unresolved, so
            # anything awaiting it would never complete. Fail fast instead.
            raise TypeError(
                "`image_embeds` must be a str, a dict or None, "
                f"got {type(image_embeds).__name__}"
            )

        placeholder = self._tracker.add("image_embeds", future, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Register an in-memory PIL image wrapped in a completed future."""
        future: asyncio.Future[Image.Image | None] = asyncio.Future()
        # Falsy inputs are normalised to None, as before.
        future.set_result(image_pil if image_pil else None)

        placeholder = self._tracker.add("image", future, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Register the (not yet awaited) fetch of ``audio_url`` as audio."""
        audio_coro = self._connector.fetch_audio_async(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", audio_coro, uuid)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Turn an ``input_audio`` payload into a data URL and parse it as audio.

        The payload is read for its ``"data"`` (base64 text) and ``"format"``
        entries.
        """
        audio_url = None
        if input_audio:
            audio_data = input_audio.get("data", "")
            # If a UUID is provided, audio data may be empty; in that case
            # no URL is built and ``None`` is forwarded.
            if audio_data:
                audio_format = input_audio.get("format", "")
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Register the (not yet awaited) fetch of ``video_url`` as video."""
        video = (
            self._connector.fetch_video_async(video_url=video_url)
            if video_url
            else None
        )

        placeholder = self._tracker.add("video", video, uuid)
        self._add_placeholder("video", placeholder)
987

988

989
def validate_chat_template(chat_template: Path | str | None):
990
991
992
993
994
    """Raises if the provided chat template appears invalid."""
    if chat_template is None:
        return

    elif isinstance(chat_template, Path) and not chat_template.exists():
995
        raise FileNotFoundError("the supplied chat template path doesn't exist")
996
997
998

    elif isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
999
1000
1001
1002
        if (
            not any(c in chat_template for c in JINJA_CHARS)
            and not Path(chat_template).exists()
        ):
1003
1004
            raise ValueError(
                f"The supplied chat template string ({chat_template}) "
1005
1006
                f"appears path-like, but doesn't exist!"
            )
1007
1008

    else:
1009
        raise TypeError(f"{type(chat_template)} is not a valid chat template type")
1010
1011


1012
def _load_chat_template(
1013
    chat_template: Path | str | None,
1014
1015
    *,
    is_literal: bool = False,
1016
) -> str | None:
1017
1018
    if chat_template is None:
        return None
1019
1020
1021

    if is_literal:
        if isinstance(chat_template, Path):
1022
1023
1024
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )
1025

1026
        return chat_template
1027

1028
    try:
1029
        with open(chat_template) as f:
1030
            return f.read()
1031
    except OSError as e:
1032
1033
1034
        if isinstance(chat_template, Path):
            raise

1035
1036
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
1037
1038
1039
1040
1041
            msg = (
                f"The supplied chat template ({chat_template}) "
                f"looks like a file path, but it failed to be "
                f"opened. Reason: {e}"
            )
1042
            raise ValueError(msg) from e
1043

1044
1045
        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
1046
1047
1048
1049
1050
1051
1052
        return _load_chat_template(chat_template, is_literal=True)


# Memoised variant of ``_load_chat_template``. Results are keyed on the
# arguments, so a template file that changes on disk is not re-read while
# its entry remains in the cache.
_cached_load_chat_template = lru_cache(_load_chat_template)


def load_chat_template(
    chat_template: Path | str | None,
    *,
    is_literal: bool = False,
) -> str | None:
    """Cached front-end for ``_load_chat_template``.

    Args:
        chat_template: ``None``, a path to a template file, or a string that
            is either a file path or the template text itself.
        is_literal: When true, ``chat_template`` is used verbatim instead of
            being opened as a file.

    Returns:
        The template text, or ``None`` if ``chat_template`` is ``None``.
    """
    return _cached_load_chat_template(chat_template, is_literal=is_literal)
1058
1059


1060
1061
1062
def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
1063
1064
1065
1066
1067
1068
1069
    for idx, elem in enumerate(texts):
        if elem in placeholder_storage:
            texts[idx] = placeholder_storage[elem].pop(0)

    return "\n".join(texts)


1070
# TODO: Let user specify how to insert multimodal tokens into prompt
1071
# (similar to chat template)
1072
1073
1074
1075
1076
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
) -> str:
1077
    """Combine multimodal prompts for a multimodal language model."""
1078

1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
    # flatten storage to make it looks like
    # {
    #   "<|image|>": 2,
    #   "<|audio|>": 1
    # }
    placeholder_counts = Counter(
        [v for elem in placeholder_storage.values() for v in elem]
    )

    if interleave_strings:
        text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
    else:
        text_prompt = "\n".join(texts)

    # Pass interleaved text further in case the user used image placeholders
    # himself, but forgot to disable the 'interleave_strings' flag

1096
    # Look through the text prompt to check for missing placeholders
1097
    missing_placeholders: list[str] = []
1098
1099
1100
1101
1102
    for placeholder in placeholder_counts:
        # For any existing placeholder in the text prompt, we leave it as is
        placeholder_counts[placeholder] -= text_prompt.count(placeholder)

        if placeholder_counts[placeholder] < 0:
1103
1104
1105
1106
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
1107
1108
                "when manually placing image placeholders.",
                interleave_strings,
1109
1110
            )
            logger.debug("Input prompt: %s", text_prompt)
1111
1112
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
1113
1114
                "actual multimodal data items."
            )
1115

1116
        missing_placeholders.extend([placeholder] * placeholder_counts[placeholder])
1117

1118
1119
    # NOTE: Default behaviour: we always add missing placeholders
    # at the front of the prompt, if interleave_strings=False
1120
    return "\n".join(missing_placeholders + [text_prompt])
1121
1122


1123
1124
# No need to validate using Pydantic again: these helpers only narrow the
# static type via ``typing.cast`` and perform no runtime check.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)
# Need to validate url objects
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python
_ContentPart: TypeAlias = str | dict[str, str] | InputAudio | PILImage

# Define a mapping from part types to their corresponding parsing functions.
# Each function receives the whole content part and returns only its payload
# (text, URL, embeddings, ...), or None when the payload is absent.
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    "text": lambda part: _TextParser(part).get("text", None),
    "thinking": lambda part: _ThinkParser(part).get("thinking", None),
    "input_text": lambda part: _TextParser(part).get("text", None),
    "input_image": lambda part: _ResponsesInputImageParser(part).get("image_url", None),
    "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
    "image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
    "image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
    "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
    "input_audio": lambda part: _InputAudioParser(part).get("input_audio", None),
    "refusal": lambda part: _RefusalParser(part).get("refusal", None),
    "video_url": lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
}


def _parse_chat_message_content_mm_part(
1158
1159
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
1160
    """
1161
    Parses a given multi-modal content part based on its type.
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL).

    Raises:
        ValueError: If the 'type' field is missing and no direct URL is found.
    """
    assert isinstance(
1175
1176
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
1177
    part_type = part.get("type", None)
1178
    uuid = part.get("uuid", None)
1179

1180
    if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None:  # noqa: E501
1181
1182
1183
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
1184
1185
        # We only support 'auto', which is the default
        if part_type == "image_url" and part.get("detail", "auto") != "auto":
1186
            logger.warning(
1187
                "'image_url.detail' is currently not supported and will be ignored."
1188
            )
1189
1190
1191
1192

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
1193
    # 'type' is required field by pydantic
1194
1195
    if part_type is None or uuid is not None:
        if "image_url" in part:
1196
            image_params = cast(CustomChatCompletionContentSimpleImageParam, part)
1197
1198
1199
1200
1201
1202
1203
1204
            image_url = image_params.get("image_url", None)
            if isinstance(image_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                image_url = image_url.get("url", None)
            return "image_url", image_url
        if "image_pil" in part:
            # "image_pil" could be None if UUID is provided.
1205
            image_params = cast(  # type: ignore
1206
1207
1208
1209
1210
1211
                CustomChatCompletionContentPILImageParam, part
            )
            image_pil = image_params.get("image_pil", None)
            return "image_pil", image_pil
        if "image_embeds" in part:
            # "image_embeds" could be None if UUID is provided.
1212
            image_params = cast(  # type: ignore
1213
1214
1215
1216
1217
                ChatCompletionContentPartImageEmbedsParam, part
            )
            image_embeds = image_params.get("image_embeds", None)
            return "image_embeds", image_embeds
        if "audio_url" in part:
1218
            audio_params = cast(CustomChatCompletionContentSimpleAudioParam, part)
1219
1220
1221
1222
1223
1224
            audio_url = audio_params.get("audio_url", None)
            if isinstance(audio_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                audio_url = audio_url.get("url", None)
            return "audio_url", audio_url
1225
        if part.get("input_audio") is not None:
1226
            input_audio_params = cast(dict[str, str], part)
1227
            return "input_audio", input_audio_params
1228
        if "video_url" in part:
1229
            video_params = cast(CustomChatCompletionContentSimpleVideoParam, part)
1230
1231
1232
1233
1234
1235
            video_url = video_params.get("video_url", None)
            if isinstance(video_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                video_url = video_url.get("url", None)
            return "video_url", video_url
1236
1237
1238
1239
1240
1241
1242
1243
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    return part_type, "unknown part_type content"


1244
# Part types that are dropped (with a warning) when their parsed content
# is None.
PART_TYPES_TO_SKIP_NONE_CONTENT = (
    "text",
    "refusal",
)
1248

1249

1250
1251
1252
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse all content parts of one message into a conversation message.

    Args:
        role: Role to assign to the resulting message.
        parts: The content parts of the message.
        mm_tracker: Tracker whose parser records the multimodal items.
        wrap_dicts: When true, the parsed parts are kept as a list of
            dictionaries; otherwise they are merged into one string.
        interleave_strings: Forwarded to the part parser and to the prompt
            assembly for the string format.

    Returns:
        A single-element list holding the parsed message.
    """
    content = list[_ContentPart]()

    mm_parser = mm_tracker.create_parser()

    for part in parts:
        parse_res = _parse_chat_message_content_part(
            part,
            mm_parser,
            wrap_dicts=wrap_dicts,
            interleave_strings=interleave_strings,
        )
        # Falsy results (None, empty strings) are not kept.
        if parse_res:
            content.append(parse_res)

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role, content=content)]  # type: ignore

    texts = cast(list[str], content)
    mm_placeholder_storage = mm_parser.mm_placeholder_storage()
    if mm_placeholder_storage:
        text_prompt = _get_full_multimodal_text_prompt(
            mm_placeholder_storage, texts, interleave_strings
        )
    else:
        text_prompt = "\n".join(texts)

    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> _ContentPart | None:
    """Parses a single part of a conversation. If wrap_dicts is True,
    structured dictionary pieces for texts and images will be
    wrapped in dictionaries, i.e., {"type": "text", "text": ...} and
    {"type": "image"}, respectively. Otherwise multimodal data will be
    handled by mm_parser, and texts will be returned as strings to be joined
    with multimodal placeholders.

    Raises:
        NotImplementedError: If the part type is not supported.
    """
    if isinstance(part, str):  # Handle plain text parts
        return part

    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)

    # Parts whose type is listed in PART_TYPES_TO_SKIP_NONE_CONTENT are
    # skipped with a warning when their content is None. Other part types
    # are passed on even without content.
    if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None:
        logger.warning(
            "Skipping multimodal part '%s' (type: '%s') "
            "with empty / unparsable content.",
            part,
            part_type,
        )
        return None

    if part_type in ("text", "input_text", "refusal", "thinking"):
        str_content = cast(str, content)
        if wrap_dicts:
            return {"type": "text", "text": str_content}
        return str_content

    # For media items, forward a user-provided uuid (as a string) if there
    # is one; otherwise None is passed to the parser.
    uuid = part.get("uuid", None)
    if uuid is not None:
        uuid = str(uuid)

    if part_type == "image_pil":
        image_content = cast(Image.Image, content) if content is not None else None
        mm_parser.parse_image_pil(image_content, uuid)
        modality = "image"
    elif part_type in ("image_url", "input_image"):
        str_content = cast(str, content)
        mm_parser.parse_image(str_content, uuid)
        modality = "image"
    elif part_type == "image_embeds":
        content = cast(str | dict[str, str], content) if content is not None else None
        mm_parser.parse_image_embeds(content, uuid)
        modality = "image"
    elif part_type == "audio_url":
        str_content = cast(str, content)
        mm_parser.parse_audio(str_content, uuid)
        modality = "audio"
    elif part_type == "input_audio":
        dict_content = cast(InputAudio, content)
        mm_parser.parse_input_audio(dict_content, uuid)
        modality = "audio"
    elif part_type == "video_url":
        str_content = cast(str, content)
        mm_parser.parse_video(str_content, uuid)
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")

    if wrap_dicts:
        return {"type": modality}
    if interleave_strings:
        # Generic per-modality marker; it is replaced by the real
        # placeholder when the text prompt is assembled.
        return MODALITY_PLACEHOLDERS_MAP[modality]
    return None
1362
1363


1364
1365
1366
1367
1368
# No need to validate using Pydantic again: these helpers only narrow the
# static type via ``typing.cast`` and perform no runtime check.
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


1369
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: _ChatTemplateContentFormat,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Convert one request message into conversation message(s).

    Args:
        message: The incoming chat message.
        mm_tracker: Tracker that records multimodal items found in the
            message content.
        content_format: ``"openai"`` keeps the content as a list of
            dictionaries; any other value produces a single string.
        interleave_strings: Forwarded to the content-part parser.

    Returns:
        The parsed messages, with ``tool_calls`` (assistant role),
        ``tool_call_id`` (tool role) and ``name`` copied over from
        ``message`` where present.
    """
    role = message["role"]
    content = message.get("content")

    if content is None:
        content = []
    elif isinstance(content, str):
        # Normalise plain-string content to a single text part.
        content = [ChatCompletionContentPartTextParam(type="text", text=content)]
    result = _parse_chat_message_content_parts(
        role,
        content,  # type: ignore
        mm_tracker,
        wrap_dicts=(content_format == "openai"),
        interleave_strings=interleave_strings,
    )

    for result_msg in result:
        if role == "assistant":
            parsed_msg = _AssistantParser(message)

            # The 'tool_calls' is not None check ensures compatibility.
            # It's needed only if downstream code doesn't strictly
            # follow the OpenAI spec.
            if "tool_calls" in parsed_msg and parsed_msg["tool_calls"] is not None:
                result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
        elif role == "tool":
            parsed_msg = _ToolParser(message)
            if "tool_call_id" in parsed_msg:
                result_msg["tool_call_id"] = parsed_msg["tool_call_id"]

        if "name" in message and isinstance(message["name"], str):
            result_msg["name"] = message["name"]

    return result

1409

1410
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
1411
1412
1413
1414
1415
1416
    # per the Transformers docs & maintainers, tool call arguments in
    # assistant-role messages with tool_calls need to be dicts not JSON str -
    # this is how tool-use chat templates will expect them moving forwards
    # so, for messages that have tool_calls, parse the string (which we get
    # from openAI format) to dict
    for message in messages:
1417
1418
1419
1420
1421
        if (
            message["role"] == "assistant"
            and "tool_calls" in message
            and isinstance(message["tool_calls"], list)
        ):
1422
            for item in message["tool_calls"]:
1423
1424
1425
1426
1427
                # if arguments is None or empty string, set to {}
                if content := item["function"].get("arguments"):
                    item["function"]["arguments"] = json.loads(content)
                else:
                    item["function"]["arguments"] = {}
1428
1429


1430
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Parse request messages and collect their multimodal data.

    Args:
        messages: The chat messages of the request.
        model_config: Model configuration; its multimodal config decides
            whether multimodal strings are interleaved.
        tokenizer: Tokenizer handed to the multimodal item tracker.
        content_format: Content format expected by the chat template.

    Returns:
        A tuple of the parsed conversation, the tracker's multimodal data
        and the tracker's multimodal uuids.
    """
    conversation: list[ConversationMessage] = []
    mm_tracker = MultiModalItemTracker(model_config, tokenizer)

    # Interleaving only applies to the "string" content format and is the
    # same for every message, so it is evaluated once.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    for msg in messages:
        sub_messages = _parse_chat_message_content(
            msg,
            mm_tracker,
            content_format,
            interleave_strings=interleave_strings,
        )

        conversation.extend(sub_messages)

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
1460
1461


1462
def parse_chat_messages_futures(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    Awaitable[MultiModalDataDict | None],
    MultiModalUUIDDict | None,
]:
    """Parse request messages; the multimodal data is returned as an awaitable.

    Args:
        messages: The chat messages of the request.
        model_config: Model configuration; its multimodal config decides
            whether multimodal strings are interleaved.
        tokenizer: Tokenizer handed to the multimodal item tracker.
        content_format: Content format expected by the chat template.

    Returns:
        A tuple of the parsed conversation, an awaitable for the tracker's
        multimodal data and the tracker's multimodal uuids.
    """
    conversation: list[ConversationMessage] = []
    mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)

    # Interleaving only applies to the "string" content format and is the
    # same for every message, so it is evaluated once.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    for msg in messages:
        sub_messages = _parse_chat_message_content(
            msg,
            mm_tracker,
            content_format,
            interleave_strings=interleave_strings,
        )

        conversation.extend(sub_messages)

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
1492
1493


1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
    """Jinja2 extension handling the ``generation`` ... ``endgeneration`` tag.

    Only parsing is implemented here, so that templates using the tag can be
    parsed when resolving chat template kwargs.
    """

    tags = {"generation"}

    def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
        """Parse a ``generation`` block into a ``CallBlock`` node."""
        # Consume the opening tag token, remembering it for its line number.
        opening_token = next(parser.stream)
        block_body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
        node = jinja2.nodes.CallBlock(
            self.call_method("_generation_support"), [], [], block_body
        )
        return node.set_lineno(opening_token.lineno)


1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
def _resolve_chat_template_kwargs(
    chat_template: str,
):
    """Return the undeclared variables referenced by ``chat_template``.

    The template is only parsed (in an immutable sandboxed environment),
    never rendered.
    """
    sandbox = jinja2.sandbox.ImmutableSandboxedEnvironment(
        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
        lstrip_blocks=True,
        trim_blocks=True,
    )
    return jinja2.meta.find_undeclared_variables(sandbox.parse(chat_template))


# Memoised variant, keyed on the template text.
_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)


1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
@lru_cache
def _get_hf_base_chat_template_params() -> frozenset[str]:
    """Names of the explicit parameters of HF's base ``apply_chat_template``.

    The names are read from the signature of
    ``PreTrainedTokenizer.apply_chat_template``, so that tokenizers which
    receive these standard parameters through ``**kwargs`` stay compatible.
    The ``*args`` / ``**kwargs`` placeholders themselves are left out.
    """
    variadic_kinds = (
        inspect.Parameter.VAR_KEYWORD,
        inspect.Parameter.VAR_POSITIONAL,
    )
    # HF's base class is the single source of truth for the signature.
    signature = inspect.signature(PreTrainedTokenizer.apply_chat_template)
    named_params = [
        param.name
        for param in signature.parameters.values()
        if param.kind not in variadic_kinds
    ]
    return frozenset(named_params)


1541
def resolve_chat_template_kwargs(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    chat_template: str,
    chat_template_kwargs: dict[str, Any],
    raise_on_unexpected: bool = True,
) -> dict[str, Any]:
    """Filter ``chat_template_kwargs`` down to the entries that are accepted.

    An entry is kept when its key is an explicit keyword of
    ``tokenizer.apply_chat_template``, an undeclared variable of
    ``chat_template``, or a parameter of HF's base ``apply_chat_template``.
    ``chat_template`` and ``tokenize`` are never kept.

    Args:
        tokenizer: Tokenizer whose ``apply_chat_template`` is inspected.
        chat_template: The already-resolved template text.
        chat_template_kwargs: Candidate keyword arguments.
        raise_on_unexpected: When true, raise if ``chat_template`` or
            ``tokenize`` is among the candidates instead of dropping it.

    Returns:
        The accepted subset of ``chat_template_kwargs``.

    Raises:
        ValueError: If ``raise_on_unexpected`` is true and a reserved key is
            present.
    """
    # We exclude chat_template from kwargs here, because
    # chat template has been already resolved at this stage
    unexpected_vars = {"chat_template", "tokenize"}
    if raise_on_unexpected and (
        unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys()
    ):
        raise ValueError(
            "Found unexpected chat template kwargs from request: "
            f"{unexpected_in_kwargs}"
        )

    fn_kw = {
        k
        for k in chat_template_kwargs
        if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
    }
    template_vars = _cached_resolve_chat_template_kwargs(chat_template)

    # Allow standard HF parameters even if tokenizer uses **kwargs to receive them
    hf_base_params = _get_hf_base_chat_template_params()

    accept_vars = (fn_kw | template_vars | hf_base_params) - unexpected_vars
    return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars}
1570
1571


1572
def apply_hf_chat_template(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    conversation: list[ConversationMessage],
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    *,
    model_config: ModelConfig,
    **kwargs: Any,
) -> str:
    """Render ``conversation`` to prompt text with a HF chat template.

    Args:
        tokenizer: Tokenizer whose ``apply_chat_template`` does the rendering.
        conversation: The parsed conversation messages.
        chat_template: Explicit template, or ``None`` to resolve one.
        tools: Tool definitions passed to the template, if any.
        model_config: Model configuration used when resolving the template.
        **kwargs: Extra chat template kwargs; they are filtered through
            ``resolve_chat_template_kwargs`` before use.

    Returns:
        The rendered (untokenized) prompt.

    Raises:
        ValueError: If no chat template can be resolved, or if applying the
            template fails for any reason (the original exception is chained).
    """
    hf_chat_template = resolve_hf_chat_template(
        tokenizer,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
    )

    if hf_chat_template is None:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        )

    resolved_kwargs = resolve_chat_template_kwargs(
        tokenizer=tokenizer,
        chat_template=hf_chat_template,
        chat_template_kwargs=kwargs,
    )

    try:
        return tokenizer.apply_chat_template(
            conversation=conversation,  # type: ignore[arg-type]
            tools=tools,  # type: ignore[arg-type]
            chat_template=hf_chat_template,
            tokenize=False,
            **resolved_kwargs,
        )

    # External library exceptions can sometimes occur despite the framework's
    # internal exception management capabilities.
    except Exception as e:
        # Log and report any library-related exceptions for further
        # investigation.
        logger.exception(
            "An error occurred in `transformers` while applying chat template"
        )
        raise ValueError(str(e)) from e
1619

1620

1621
1622
def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: list[ChatCompletionMessageParam],
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    **kwargs: Any,
) -> list[int]:
    """Apply the Mistral tokenizer's chat template to ``messages``.

    Args:
        tokenizer: The Mistral tokenizer.
        messages: The chat messages of the request.
        chat_template: Passed to ``resolve_mistral_chat_template`` only.
        tools: Tool definitions, if any.
        **kwargs: Forwarded to ``resolve_mistral_chat_template`` and to
            ``tokenizer.apply_chat_template``.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns (annotated as the
        prompt token ids).

    Raises:
        ValueError: If applying the template fails for any reason (the
            original exception is chained).
    """
    from mistral_common.exceptions import MistralCommonException

    # The return value of resolve_mistral_chat_template is always None,
    # and we won't use it.
    resolve_mistral_chat_template(
        chat_template=chat_template,
        **kwargs,
    )

    try:
        return tokenizer.apply_chat_template(
            messages=messages,
            tools=tools,
            **kwargs,
        )
    # mistral-common uses assert statements to stop processing of input
    # if input does not comply with the expected format.
    # We convert those assertion errors to ValueErrors so they can be
    # properly caught in the preprocessing_input step
    except (AssertionError, MistralCommonException) as e:
        raise ValueError(str(e)) from e

    # External library exceptions can sometimes occur despite the framework's
    # internal exception management capabilities.
    except Exception as e:
        # Log and report any library-related exceptions for further
        # investigation.
        logger.exception(
            "An error occurred in `mistral_common` while applying chat template"
        )
        raise ValueError(str(e)) from e
1659

1660

1661
1662
1663
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]) -> int:
    """Count the tool calls made by assistant messages in ``conversation``.

    Messages of other roles are ignored, as are assistant messages whose
    ``tool_calls`` entry is missing or ``None``.
    """
    total = 0
    for msg in conversation:
        if msg["role"] != "assistant":
            continue
        tool_calls = msg.get("tool_calls")
        if tool_calls is not None:
            # ``tool_calls`` may be any iterable, so materialise it to count.
            total += len(list(tool_calls))
    return total


1670
1671
1672
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build a tool call id.

    Args:
        id_type: ``"kimi_k2"`` selects the ``functions.<name>:<idx>`` form;
            any other value produces a random id.
        func_name: Function name, used only for the ``"kimi_k2"`` form.
        idx: Call index, used only for the ``"kimi_k2"`` form.

    Returns:
        The tool call id string.
    """
    if id_type == "kimi_k2":
        return f"functions.{func_name}:{idx}"

    # by default return random
    return f"chatcmpl-tool-{random_uuid()}"