# chat_utils.py 55.1 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
import json
from abc import ABC, abstractmethod
from collections import Counter, defaultdict, deque
from collections.abc import Awaitable, Iterable
from functools import cached_property, lru_cache, partial
from pathlib import Path
from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union, cast

import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
import transformers.utils.chat_template_utils as hf_chat_utils
from openai.types.chat import (
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionContentPartTextParam,
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
from openai.types.chat import (
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
from openai.types.responses import ResponseInputImageParam
from openai_harmony import Message as OpenAIHarmonyMessage
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin

# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Required, TypeAlias, TypedDict

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import MediaConnector
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid, supports_kw
54
55
56

logger = init_logger(__name__)

57
58
59
60
61
62
# Maps a modality name to the generic placeholder token under which the
# content parsers collect the model-specific placeholder strings.
MODALITY_PLACEHOLDERS_MAP = {
    "image": "<##IMAGE##>",
    "audio": "<##AUDIO##>",
    "video": "<##VIDEO##>",
}

63

64
65
66
67
68
69
70
71
72
73
74
75
76
77
class AudioURL(TypedDict, total=False):
    """Audio reference nested under the `audio_url` key of an audio content part."""

    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Content part of type `audio_url`; both keys are required."""

    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


78
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
79
    image_embeds: Optional[Union[str, dict[str, str]]]
80
81
82
83
84
85
86
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
87
88
89
90
91
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
92
93


94
95
96
97
98
99
100
101
102
103
104
105
106
107
class VideoURL(TypedDict, total=False):
    """Video reference nested under the `video_url` key of a video content part."""

    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Content part of type `video_url`; both keys are required."""

    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""


108
109
110
111
class PILImage(BaseModel):
    """
    A PIL.Image.Image object.
    """

    image_pil: Image.Image
    # `Image.Image` is an arbitrary (non-pydantic) class, so the model config
    # has to opt in to arbitrary field types.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    image_pil: Optional[PILImage]
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
132
133


134
135
136
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.

    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """

    image_url: Optional[str]
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
150
151
152
153


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """

    audio_url: Optional[str]
162
163


164
165
166
167
168
169
170
171
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain video_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """

    video_url: Optional[str]
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
179
180


# Julien Denize's avatar
# Julien Denize committed
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    # Optional key: the class is declared with total=False and this field is
    # not wrapped in Required.
    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


202
# Every content-part shape accepted inside a message's `content` list:
# the OpenAI-defined parts, the audio/video/embedding extensions, the
# simplified "plain URL" variants, bare strings, and thinking parts.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam,
    ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartVideoParam,
    ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPILImageParam,
    CustomChatCompletionContentSimpleImageParam,
    ChatCompletionContentPartImageEmbedsParam,
    CustomChatCompletionContentSimpleAudioParam,
    CustomChatCompletionContentSimpleVideoParam,
    str,
    CustomThinkCompletionContentParam,
]
216
217
218
219


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""

    role: Required[str]
    """The role of the message's author."""

    content: Union[str, list[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""

240

241
242
243
244
245
# Any message shape accepted in a chat request: an OpenAI-typed message,
# a message with a custom role, or an OpenAI Harmony message object.
ChatCompletionMessageParam = Union[
    OpenAIChatCompletionMessageParam,
    CustomChatCompletionMessageParam,
    OpenAIHarmonyMessage,
]
246
247


248
# TODO: Make fields ReadOnly once mypy supports it
249
250
251
252
class ConversationMessage(TypedDict, total=False):
    """A single conversation message; only `role` is a required key."""

    role: Required[str]
    """The role of the message's author."""

    content: Union[Optional[str], list[dict[str, str]]]
    """The contents of the message"""

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    name: Optional[str]
    """The name of the function to call"""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""
264
265


266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
# Passed in by user; "auto" asks for the format to be detected from the
# chat template instead of being taken as given.
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

# Used internally, i.e. the concrete formats left once "auto" is resolved
_ChatTemplateContentFormat = Literal["string", "openai"]


def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Return whether `node` is a load-context reference to `varname`."""
    return (
        isinstance(node, jinja2.nodes.Name)
        and node.ctx == "load"
        and node.name == varname
    )


def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Return whether `node` reads `key` from the variable `varname`.

    Both spellings are recognised: subscript access with a constant key
    (`varname['key']`) and attribute access (`varname.key`).
    """
    if isinstance(node, jinja2.nodes.Getitem):
        return (
            _is_var_access(node.node, varname)
            and isinstance(node.arg, jinja2.nodes.Const)
            and node.arg.value == key
        )

    if isinstance(node, jinja2.nodes.Getattr):
        return _is_var_access(node.node, varname) and node.attr == key

    return False


def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: Optional[str] = None,
) -> bool:
    """Return whether `node` reads `varname` (or `varname[key]` if `key` is set).

    Filters, tests and slice subscripts wrapped around the access are
    looked through recursively before the innermost node is matched.
    """
    if isinstance(node, jinja2.nodes.Filter):
        # A Filter node may have no inner node, hence the None check.
        return node.node is not None and _is_var_or_elems_access(
            node.node, varname, key
        )
    if isinstance(node, jinja2.nodes.Test):
        return _is_var_or_elems_access(node.node, varname, key)

    if isinstance(node, jinja2.nodes.Getitem) and isinstance(
        node.arg, jinja2.nodes.Slice
    ):
        return _is_var_or_elems_access(node.node, varname, key)

    return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339


def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
    """Yield `(node, name)` for `varname` and every variable assigned from it.

    The first pair is `(root, varname)` itself. After that, each `Assign`
    node under `root` whose right-hand side reads a known alias yields
    `(assign_node, target_name)`, and the target becomes a known alias too.
    """
    # Global variable that is implicitly defined at the root
    yield root, varname

    # Iterative BFS over alias names. Each name is expanded at most once so
    # that self-assignments and cyclic chains (`a = messages`, then
    # `messages = a`) terminate instead of re-enqueueing each other forever.
    seen_varnames = {varname}
    related_varnames = deque([varname])
    while related_varnames:
        related_varname = related_varnames.popleft()

        for assign_ast in root.find_all(jinja2.nodes.Assign):
            lhs = assign_ast.target
            rhs = assign_ast.node

            if _is_var_or_elems_access(rhs, related_varname):
                assert isinstance(lhs, jinja2.nodes.Name)
                yield assign_ast, lhs.name

                if lhs.name not in seen_varnames:
                    seen_varnames.add(lhs.name)
                    related_varnames.append(lhs.name)


# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
    """Yield `(for_node, target_name)` for loops iterating over the messages.

    A loop qualifies when its iterable reads `messages` or any variable
    assigned from it; each qualifying loop is yielded once.
    """
    messages_varnames = [
        varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
    ]

    # Search for {%- for message in messages -%} loops
    for loop_ast in root.find_all(jinja2.nodes.For):
        loop_iter = loop_ast.iter
        loop_target = loop_ast.target

        for varname in messages_varnames:
            if _is_var_or_elems_access(loop_iter, varname):
                assert isinstance(loop_target, jinja2.nodes.Name)
                yield loop_ast, loop_target.name
                break


def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    """Yield `(for_node, target_name)` for loops over a message's `content`.

    A loop qualifies when its iterable reads the `"content"` key of any
    per-message loop variable; each qualifying loop is yielded once.
    """
    item_names = [name for _, name in _iter_nodes_assign_messages_item(root)]

    # Search for {%- for content in message['content'] -%} loops
    for for_node in root.find_all(jinja2.nodes.For):
        iterates_content = any(
            _is_var_or_elems_access(for_node.iter, name, "content")
            for name in item_names
        )
        if not iterates_content:
            continue

        target = for_node.target
        assert isinstance(target, jinja2.nodes.Name)
        yield for_node, target.name


def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
    """Parse `chat_template` into a Jinja AST; log and return None on any error."""
    try:
        compiled = hf_chat_utils._compile_jinja_template(chat_template)
        template_ast = compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None

    return template_ast


381
@lru_cache(maxsize=32)
def _detect_content_format(
    chat_template: str,
    *,
    default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
    """Detect which content format a Jinja chat template expects.

    Returns "openai" when the template contains a loop over a message's
    `content`, "string" when it contains none, and `default` when the
    template cannot be parsed or its AST cannot be analysed. Results are
    memoised per `(chat_template, default)` pair.
    """
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    try:
        # Only the existence of a first match matters, not the match itself.
        next(_iter_nodes_assign_content_item(jinja_ast))
    except StopIteration:
        return "string"
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default
    else:
        return "openai"


402
403
404
405
def resolve_mistral_chat_template(
    chat_template: Optional[str],
    **kwargs: Any,
) -> Optional[str]:
    """Reject chat template overrides for the mistral tokenizer.

    Always returns None when no override is supplied.

    Raises:
        ValueError: If `chat_template` is given, or if `kwargs` contains a
            non-None `chat_template_kwargs` entry.
    """
    if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
        raise ValueError(
            "'chat_template' or 'chat_template_kwargs' cannot be overridden "
            "for mistral tokenizer."
        )

    return None

414

415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
# Keyed by (tokenizer.name_or_path, trust_remote_code). A stored None records
# that no processor chat template was obtained for that key.
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], Optional[str]]()
"""
Used in `_try_get_processor_chat_template` to avoid calling
`cached_get_processor` again if the processor fails to be loaded.

This is needed because `lru_cache` does not cache when an exception happens.
"""


def _try_get_processor_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    model_config: ModelConfig,
) -> Optional[str]:
    """Return the processor's chat template for this tokenizer, or None.

    The outcome, including a failed or empty lookup, is memoised in
    `_PROCESSOR_CHAT_TEMPLATES` so the processor is loaded at most once per
    `(name_or_path, trust_remote_code)` pair.
    """
    name_or_path = tokenizer.name_or_path
    trust_remote_code = model_config.trust_remote_code

    cache_key = (name_or_path, trust_remote_code)
    if cache_key in _PROCESSOR_CHAT_TEMPLATES:
        return _PROCESSOR_CHAT_TEMPLATES[cache_key]

    resolved: Optional[str] = None
    try:
        processor = cached_get_processor(
            name_or_path,
            processor_cls=(
                PreTrainedTokenizer,
                PreTrainedTokenizerFast,
                ProcessorMixin,
            ),
            trust_remote_code=trust_remote_code,
        )
        # Only a real processor counts; a bare tokenizer may also be returned
        # given the accepted `processor_cls` tuple.
        if isinstance(processor, ProcessorMixin):
            resolved = getattr(processor, "chat_template", None)
    except Exception:
        logger.debug(
            "Failed to load AutoProcessor chat template for %s",
            name_or_path,
            exc_info=True,
        )

    _PROCESSOR_CHAT_TEMPLATES[cache_key] = resolved
    return resolved


460
def resolve_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    *,
    model_config: ModelConfig,
) -> Optional[str]:
    """Resolve the chat template to use with a HF tokenizer.

    Sources are tried in priority order: the explicitly given template, the
    processor's template (only when `tools` is None), the tokenizer's own
    template, and finally a predefined fallback file. Returns None when no
    fallback path exists either.
    """
    # 1st priority: The given chat template
    if chat_template is not None:
        return chat_template

    # 2nd priority: AutoProcessor chat template, unless tool calling is enabled
    if tools is None:
        chat_template = _try_get_processor_chat_template(tokenizer, model_config)
        if chat_template is not None:
            return chat_template

    # 3rd priority: AutoTokenizer chat template
    try:
        return tokenizer.get_chat_template(chat_template, tools=tools)
    except Exception:
        logger.debug(
            "Failed to load AutoTokenizer chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    # 4th priority: Predefined fallbacks
    path = get_chat_template_fallback_path(
        model_type=model_config.hf_config.model_type,
        tokenizer_name_or_path=model_config.tokenizer,
    )
    if path is not None:
        logger.info_once(
            "Loading chat template fallback for %s as there isn't one "
            "defined on HF Hub.",
            tokenizer.name_or_path,
        )
        chat_template = load_chat_template(path)
    else:
        logger.debug_once(
            "There is no chat template fallback for %s", tokenizer.name_or_path
        )

    return chat_template
505
506


507
508
def _resolve_chat_template_content_format(
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """Detect the content format from whichever chat template applies.

    For HF tokenizers the template is resolved via
    `resolve_hf_chat_template`; otherwise, or when that yields no string,
    the given `chat_template` is loaded as a literal. Falls back to
    "string" when no template text is available.
    """
    if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        hf_chat_template = resolve_hf_chat_template(
            tokenizer,
            chat_template=chat_template,
            tools=tools,
            model_config=model_config,
        )
    else:
        hf_chat_template = None

    jinja_text = (
        hf_chat_template
        if isinstance(hf_chat_template, str)
        else load_chat_template(chat_template, is_literal=True)
    )

    detected_format = (
        "string"
        if jinja_text is None
        else _detect_content_format(jinja_text, default="string")
    )

    return detected_format
537
538
539


@lru_cache
def _log_chat_template_content_format(
    chat_template: Optional[str],
    given_format: ChatTemplateContentFormatOption,
    detected_format: ChatTemplateContentFormatOption,
):
    """Log the detected content format, warning if it differs from the given one.

    `chat_template` is not read in the body; it only takes part in the
    `lru_cache` key, so each distinct argument combination is logged once.
    """
    logger.info(
        "Detected the chat template content format to be '%s'. "
        "You can set `--chat-template-content-format` to override this.",
        detected_format,
    )

    if given_format != "auto" and given_format != detected_format:
        logger.warning(
            "You specified `--chat-template-content-format %s` "
            "which is different from the detected format '%s'. "
            "If our automatic detection is incorrect, please consider "
            "opening a GitHub issue so that we can improve it: "
            "https://github.com/vllm-project/vllm/issues/new/choose",
            given_format,
            detected_format,
        )

562
563
564
565
566
567

def resolve_chat_template_content_format(
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    given_format: ChatTemplateContentFormatOption,
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """Return the content format to use for the chat template.

    An explicit `given_format` ("string" or "openai") is returned unchanged.
    With "auto", the format is detected from the resolved chat template and
    the result is logged.
    """
    if given_format != "auto":
        return given_format

    detected_format = _resolve_chat_template_content_format(
        chat_template,
        tools,
        tokenizer,
        model_config=model_config,
    )

    _log_chat_template_content_format(
        chat_template,
        given_format=given_format,
        detected_format=detected_format,
    )

    return detected_format
588

589

590
# Modalities accepted by the item trackers. "image_embeds" is validated
# against the "image" item limit (the "_embeds" suffix is stripped in `add`).
ModalityStr = Literal["image", "audio", "video", "image_embeds"]
591
592
593
594
_T = TypeVar("_T")


class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """

    def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
        super().__init__()

        self._model_config = model_config
        self._tokenizer = tokenizer

        # Items and their UUIDs live in parallel per-modality lists; `add`
        # appends to both, so positions line up.
        self._items_by_modality = defaultdict[str, list[Optional[_T]]](list)
        self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list)

    @property
    def model_config(self) -> ModelConfig:
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        # Imported lazily inside the property rather than at module level.
        from vllm.model_executor.model_loader import get_model_cls

        model_cls = get_model_cls(self.model_config)
        return cast(type[SupportsMultiModal], model_cls)

    @property
    def allowed_local_media_path(self):
        return self._model_config.allowed_local_media_path

    @property
    def allowed_media_domains(self):
        return self._model_config.allowed_media_domains

    @property
    def mm_registry(self):
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        return self.mm_registry.create_processor(self.model_config)

    def add(
        self,
        modality: ModalityStr,
        item: Optional[_T],
        uuid: Optional[str] = None,
    ) -> Optional[str]:
        """
        Add a multi-modal item to the current prompt and returns the
        placeholder string to use, if any.

        An optional uuid can be added which serves as a unique identifier of the
        media.
        """
        # Embedding inputs count against the limit of their base modality.
        input_modality = modality.replace("_embeds", "")

        # Use `.get` rather than indexing: indexing the defaultdict would
        # create an empty entry that stays behind if validation raises.
        num_items = len(self._items_by_modality.get(modality, ())) + 1

        self.mm_processor.validate_num_items(input_modality, num_items)

        self._items_by_modality[modality].append(item)
        self._uuids_by_modality[modality].append(uuid)

        return self.model_cls.get_placeholder_str(modality, num_items)

    def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]:
        """Return the tracked UUIDs per modality, or None if nothing was added.

        UUIDs of `image_embeds` items are reported under the "image" key.

        Raises:
            ValueError: If both images and image embeddings were added, or
                more than one image-embeds item was added.
        """
        if not self._items_by_modality:
            return None
        mm_uuids = {}
        uuids_by_modality = dict(self._uuids_by_modality)
        if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
            raise ValueError("Mixing raw image and embedding inputs is not allowed")

        if "image_embeds" in uuids_by_modality:
            image_embeds_uuids = uuids_by_modality["image_embeds"]
            if len(image_embeds_uuids) > 1:
                raise ValueError("Only one message can have {'type': 'image_embeds'}")
            mm_uuids["image"] = uuids_by_modality["image_embeds"]
        if "image" in uuids_by_modality:
            mm_uuids["image"] = uuids_by_modality["image"]  # UUIDs of images
        if "audio" in uuids_by_modality:
            mm_uuids["audio"] = uuids_by_modality["audio"]  # UUIDs of audios
        if "video" in uuids_by_modality:
            mm_uuids["video"] = uuids_by_modality["video"]  # UUIDs of videos
        return mm_uuids

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        raise NotImplementedError


686
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
    """Synchronous tracker whose items are already-loaded media objects."""

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Return the tracked items per modality, or None if nothing was added.

        A single `image_embeds` item is reported directly under the "image"
        key, while the other modalities map to lists of items.

        Raises:
            ValueError: If both images and image embeddings were added, or
                more than one image-embeds item was added.
        """
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        items_by_modality = dict(self._items_by_modality)
        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError("Mixing raw image and embedding inputs is not allowed")

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError("Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)


712
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
    """Asynchronous tracker whose items are awaitables resolving to media."""

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Await all tracked items and return them per modality.

        Returns None if nothing was added. A single `image_embeds` item is
        reported directly under the "image" key, while the other modalities
        map to lists of resolved items.

        Raises:
            ValueError: If both images and image embeddings were added, or
                more than one image-embeds item was added.
        """
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        items_by_modality = {}
        for modality, items in self._items_by_modality.items():
            # `None` items (e.g. media referenced only by UUID) are replaced
            # by a no-op awaitable so that the gathered results keep one
            # slot per item; `asyncio.sleep(0)` resolves to None.
            coros = [
                item if item is not None else asyncio.sleep(0) for item in items
            ]
            items_by_modality[modality] = await asyncio.gather(*coros)

        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError("Mixing raw image and embedding inputs is not allowed")

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError("Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    """Interface for parsing multi-modal content parts of a message.

    Subclasses implement one `parse_*` method per content type; this base
    class collects the placeholder strings recorded along the way.
    """

    def __init__(self) -> None:
        super().__init__()

        # stores model placeholders list with corresponding
        # general MM placeholder:
        # {
        #   "<##IMAGE##>": ["<image>", "<image>", "<image>"],
        #   "<##AUDIO##>": ["<audio>", "<audio>"]
        # }
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(self, modality: ModalityStr, placeholder: Optional[str]):
        """Record `placeholder` under the modality's generic token, if non-empty."""
        mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
        if placeholder:
            self._placeholder_storage[mod_placeholder].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return a plain-dict shallow copy of the recorded placeholders."""
        return dict(self._placeholder_storage)

    @abstractmethod
    def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: Union[str, dict[str, str], None],
        uuid: Optional[str] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
        raise NotImplementedError

799
800
801
802
803
804

class MultiModalContentParser(BaseMultiModalContentParser):
    """Synchronous parser: fetches each media item and registers it immediately."""

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        multimodal_config = self._tracker.model_config.multimodal_config
        # `getattr` with a default tolerates a config object without this field.
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
        self._connector = MediaConnector(
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
        image = self._connector.fetch_image(image_url) if image_url else None

        placeholder = self._tracker.add("image", image, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: Union[str, dict[str, str], None],
        uuid: Optional[str] = None,
    ) -> None:
        """Register image embeddings given as one string, a dict of strings, or None.

        Raises:
            TypeError: If `image_embeds` is none of the supported types.
        """
        if image_embeds is None:
            embeds = None
        elif isinstance(image_embeds, str):
            embeds = self._connector.fetch_image_embedding(image_embeds)
        elif isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # the placeholder further down.
            raise TypeError(
                "image_embeds must be a str, a dict or None, "
                f"got {type(image_embeds).__name__}"
            )

        placeholder = self._tracker.add("image_embeds", embeds, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_pil(
        self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
    ) -> None:
        placeholder = self._tracker.add("image", image_pil, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
        audio = self._connector.fetch_audio(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", audio, uuid)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
    ) -> None:
        """Convert inline audio into a base64 data URL and parse it as audio."""
        # If a UUID is provided, audio data may be empty; the URL then
        # stays None.
        audio_url = None
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
        video = self._connector.fetch_video(video_url=video_url) if video_url else None

        placeholder = self._tracker.add("video", video, uuid)
        self._add_placeholder("video", placeholder)
873

874
875
876
877
878
879

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Content parser that registers media items on an async tracker.

    URL-based media is handed to the tracker as whatever the connector's
    ``fetch_*_async`` method returns (presumably an awaitable), while items
    that are already available are wrapped in an ``asyncio.Future`` that is
    resolved immediately. Every ``parse_*`` method records the placeholder
    returned by the tracker under the matching modality.
    """

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        multimodal_config = self._tracker.model_config.multimodal_config
        # ``multimodal_config`` may lack the attribute (or be None), hence
        # the defensive getattr.
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
        self._connector = MediaConnector(
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
        """Schedule fetching of ``image_url`` (or track ``None`` if falsy)."""
        image_coro = self._connector.fetch_image_async(image_url) if image_url else None

        placeholder = self._tracker.add("image", image_coro, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: Union[str, dict[str, str], None],
        uuid: Optional[str] = None,
    ) -> None:
        """Load image embeddings and register them as a resolved future.

        Args:
            image_embeds: A single reference string, a dict of named
                reference strings, or ``None``.
            uuid: Optional identifier forwarded to the tracker with the item.

        Raises:
            TypeError: If ``image_embeds`` is none of the supported types.
                Previously such a value left the future unresolved, so
                anything awaiting it would wait forever.
        """
        if isinstance(image_embeds, dict):
            resolved = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
        elif isinstance(image_embeds, str):
            resolved = self._connector.fetch_image_embedding(image_embeds)
        elif image_embeds is None:
            resolved = None
        else:
            raise TypeError(
                "image_embeds must be a str, a dict or None, "
                f"got {type(image_embeds)}"
            )

        future: asyncio.Future[Union[str, dict[str, str], None]] = asyncio.Future()
        future.set_result(resolved)

        placeholder = self._tracker.add("image_embeds", future, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_pil(
        self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
    ) -> None:
        """Register an already-decoded PIL image as a resolved future."""
        future: asyncio.Future[Optional[Image.Image]] = asyncio.Future()
        # A falsy value is normalised to None, as before.
        future.set_result(image_pil if image_pil else None)

        placeholder = self._tracker.add("image", future, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
        """Schedule fetching of ``audio_url`` (or track ``None`` if falsy)."""
        audio_coro = self._connector.fetch_audio_async(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", audio_coro, uuid)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
    ) -> None:
        """Convert an inline ``input_audio`` payload to a data URL and parse it.

        The payload is read for the keys ``"data"`` and ``"format"``.
        """
        # Stays None unless there is actual payload data; if a UUID is
        # provided, audio data may be empty.
        audio_url = None
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
        """Schedule fetching of ``video_url`` (or track ``None`` if falsy)."""
        video = (
            self._connector.fetch_video_async(video_url=video_url)
            if video_url
            else None
        )

        placeholder = self._tracker.add("video", video, uuid)
        self._add_placeholder("video", placeholder)
961

962

963
964
965
966
967
968
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
    """Raises if the provided chat template appears invalid.

    Args:
        chat_template: ``None``, a ``Path`` to a template file, or a string
            holding either a template or a path to one.

    Raises:
        FileNotFoundError: If a ``Path`` is given that does not exist.
        ValueError: If a string contains none of the characters typical of
            a Jinja template and does not name an existing path either.
        TypeError: If the argument is neither ``None``, ``Path`` nor ``str``.
    """
    if chat_template is None:
        return

    if isinstance(chat_template, Path):
        if not chat_template.exists():
            raise FileNotFoundError("the supplied chat template path doesn't exist")
        # An existing path is valid. It must not fall through to the
        # TypeError at the bottom, as it did before.
        return

    if isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
        looks_like_template = any(c in chat_template for c in JINJA_CHARS)
        if not looks_like_template and not Path(chat_template).exists():
            raise ValueError(
                f"The supplied chat template string ({chat_template}) "
                f"appears path-like, but doesn't exist!"
            )
        return

    raise TypeError(f"{type(chat_template)} is not a valid chat template type")
984
985


986
def _load_chat_template(
987
988
989
990
    chat_template: Optional[Union[Path, str]],
    *,
    is_literal: bool = False,
) -> Optional[str]:
991
992
    if chat_template is None:
        return None
993
994
995

    if is_literal:
        if isinstance(chat_template, Path):
996
997
998
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )
999

1000
        return chat_template
1001

1002
    try:
1003
        with open(chat_template) as f:
1004
            return f.read()
1005
    except OSError as e:
1006
1007
1008
        if isinstance(chat_template, Path):
            raise

1009
1010
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
1011
1012
1013
1014
1015
            msg = (
                f"The supplied chat template ({chat_template}) "
                f"looks like a file path, but it failed to be "
                f"opened. Reason: {e}"
            )
1016
            raise ValueError(msg) from e
1017

1018
1019
        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
        return _load_chat_template(chat_template, is_literal=True)


# Memoized variant of ``_load_chat_template``; results are cached per
# (chat_template, is_literal) argument combination.
_cached_load_chat_template = lru_cache(_load_chat_template)


def load_chat_template(
    chat_template: Optional[Union[Path, str]],
    *,
    is_literal: bool = False,
) -> Optional[str]:
    """Load a chat template through the memoized loader.

    Args:
        chat_template: ``None``, a path, or the template text.
        is_literal: Forwarded to the loader; when true the value is taken
            to be the template text.

    Returns:
        Whatever the cached loader returns for these arguments.
    """
    loaded = _cached_load_chat_template(chat_template, is_literal=is_literal)
    return loaded
1032
1033


1034
1035
1036
def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
1037
1038
1039
1040
1041
1042
1043
    for idx, elem in enumerate(texts):
        if elem in placeholder_storage:
            texts[idx] = placeholder_storage[elem].pop(0)

    return "\n".join(texts)


1044
# TODO: Let user specify how to insert multimodal tokens into prompt
1045
# (similar to chat template)
1046
1047
1048
1049
1050
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
) -> str:
1051
    """Combine multimodal prompts for a multimodal language model."""
1052

1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
    # flatten storage to make it looks like
    # {
    #   "<|image|>": 2,
    #   "<|audio|>": 1
    # }
    placeholder_counts = Counter(
        [v for elem in placeholder_storage.values() for v in elem]
    )

    if interleave_strings:
        text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
    else:
        text_prompt = "\n".join(texts)

    # Pass interleaved text further in case the user used image placeholders
    # himself, but forgot to disable the 'interleave_strings' flag

1070
    # Look through the text prompt to check for missing placeholders
1071
    missing_placeholders: list[str] = []
1072
1073
1074
1075
1076
    for placeholder in placeholder_counts:
        # For any existing placeholder in the text prompt, we leave it as is
        placeholder_counts[placeholder] -= text_prompt.count(placeholder)

        if placeholder_counts[placeholder] < 0:
1077
1078
1079
1080
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
1081
1082
                "when manually placing image placeholders.",
                interleave_strings,
1083
1084
            )
            logger.debug("Input prompt: %s", text_prompt)
1085
1086
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
1087
1088
                "actual multimodal data items."
            )
1089

1090
        missing_placeholders.extend([placeholder] * placeholder_counts[placeholder])
1091

1092
1093
    # NOTE: Default behaviour: we always add missing placeholders
    # at the front of the prompt, if interleave_strings=False
1094
    return "\n".join(missing_placeholders + [text_prompt])
1095
1096


1097
1098
# Typing-only helpers: ``typing.cast`` returns its argument unchanged, so
# these do not validate anything (no need to validate using Pydantic again).
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)

# URL-carrying parts do need validation, so these go through Pydantic.
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]

# Mapping from a part's "type" value to the function that extracts its
# payload; each extractor yields None when the payload field is absent.
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    "text": lambda part: _TextParser(part).get("text"),
    "thinking": lambda part: _ThinkParser(part).get("thinking"),
    "input_text": lambda part: _TextParser(part).get("text"),
    "input_image": lambda part: _ResponsesInputImageParser(part).get("image_url"),
    "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url"),
    "image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds"),
    "image_pil": lambda part: _PILImageParser(part).get("image_pil"),
    "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url"),
    "input_audio": lambda part: _InputAudioParser(part).get("input_audio"),
    "refusal": lambda part: _RefusalParser(part).get("refusal"),
    "video_url": lambda part: _VideoParser(part).get("video_url", {}).get("url"),
}


def _parse_chat_message_content_mm_part(
1132
1133
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
1134
    """
1135
    Parses a given multi-modal content part based on its type.
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL).

    Raises:
        ValueError: If the 'type' field is missing and no direct URL is found.
    """
    assert isinstance(
1149
1150
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
1151
    part_type = part.get("type", None)
1152
    uuid = part.get("uuid", None)
1153

1154
    if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None:  # noqa: E501
1155
1156
1157
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
1158
1159
        # We only support 'auto', which is the default
        if part_type == "image_url" and part.get("detail", "auto") != "auto":
1160
            logger.warning(
1161
                "'image_url.detail' is currently not supported and will be ignored."
1162
            )
1163
1164
1165
1166

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
1167
    # 'type' is required field by pydantic
1168
1169
    if part_type is None or uuid is not None:
        if "image_url" in part:
1170
            image_params = cast(CustomChatCompletionContentSimpleImageParam, part)
1171
1172
1173
1174
1175
1176
1177
1178
            image_url = image_params.get("image_url", None)
            if isinstance(image_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                image_url = image_url.get("url", None)
            return "image_url", image_url
        if "image_pil" in part:
            # "image_pil" could be None if UUID is provided.
1179
            image_params = cast(  # type: ignore
1180
1181
1182
1183
1184
1185
                CustomChatCompletionContentPILImageParam, part
            )
            image_pil = image_params.get("image_pil", None)
            return "image_pil", image_pil
        if "image_embeds" in part:
            # "image_embeds" could be None if UUID is provided.
1186
            image_params = cast(  # type: ignore
1187
1188
1189
1190
1191
                ChatCompletionContentPartImageEmbedsParam, part
            )
            image_embeds = image_params.get("image_embeds", None)
            return "image_embeds", image_embeds
        if "audio_url" in part:
1192
            audio_params = cast(CustomChatCompletionContentSimpleAudioParam, part)
1193
1194
1195
1196
1197
1198
            audio_url = audio_params.get("audio_url", None)
            if isinstance(audio_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                audio_url = audio_url.get("url", None)
            return "audio_url", audio_url
1199
        if part.get("input_audio") is not None:
1200
            input_audio_params = cast(dict[str, str], part)
1201
            return "input_audio", input_audio_params
1202
        if "video_url" in part:
1203
            video_params = cast(CustomChatCompletionContentSimpleVideoParam, part)
1204
1205
1206
1207
1208
1209
            video_url = video_params.get("video_url", None)
            if isinstance(video_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                video_url = video_url.get("url", None)
            return "video_url", video_url
1210
1211
1212
1213
1214
1215
1216
1217
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    return part_type, "unknown part_type content"


1218
# Part types whose parts are skipped when their parsed content is None.
PART_TYPES_TO_SKIP_NONE_CONTENT = ("text", "refusal")
1222

1223

1224
1225
1226
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse all content parts of one message into a conversation message.

    Args:
        role: Role assigned to the resulting message.
        parts: The content parts of the message.
        mm_tracker: Tracker from which a fresh content parser is created.
        wrap_dicts: When true the parsed parts are kept as a list;
            otherwise they are joined into one text prompt.
        interleave_strings: Forwarded to part parsing and prompt assembly.

    Returns:
        A single-element list holding the resulting message.
    """
    mm_parser = mm_tracker.create_parser()

    parsed_parts = (
        _parse_chat_message_content_part(
            part,
            mm_parser,
            wrap_dicts=wrap_dicts,
            interleave_strings=interleave_strings,
        )
        for part in parts
    )
    # Falsy results (None, empty strings) are dropped.
    content: list[_ContentPart] = [parsed for parsed in parsed_parts if parsed]

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role, content=content)]  # type: ignore

    texts = cast(list[str], content)
    mm_placeholder_storage = mm_parser.mm_placeholder_storage()
    text_prompt = (
        _get_full_multimodal_text_prompt(
            mm_placeholder_storage, texts, interleave_strings
        )
        if mm_placeholder_storage
        else "\n".join(texts)
    )

    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
1262
1263
1264
1265
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
1266
    interleave_strings: bool,
1267
) -> Optional[_ContentPart]:
1268
1269
1270
1271
1272
1273
1274
1275
    """Parses a single part of a conversation. If wrap_dicts is True,
    structured dictionary pieces for texts and images will be
    wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
    {"type": "image"}, respectively. Otherwise multimodal data will be
    handled by mm_parser, and texts will be returned as strings to be joined
    with multimodal placeholders.
    """
    if isinstance(part, str):  # Handle plain text parts
1276
        return part
1277
1278
    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)
1279
    # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
1280
    # content is None, log a warning and skip
1281
    if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None:
1282
        logger.warning(
1283
            "Skipping multimodal part '%s' (type: '%s') "
1284
1285
1286
1287
            "with empty / unparsable content.",
            part,
            part_type,
        )
1288
1289
        return None

Julien Denize's avatar
Julien Denize committed
1290
    if part_type in ("text", "input_text", "refusal", "thinking"):
1291
1292
        str_content = cast(str, content)
        if wrap_dicts:
1293
            return {"type": "text", "text": str_content}
1294
1295
        else:
            return str_content
1296

1297
1298
1299
1300
1301
1302
    # For media items, if a user has provided one, use it. Otherwise, insert
    # a placeholder empty uuid.
    uuid = part.get("uuid", None)
    if uuid is not None:
        uuid = str(uuid)

1303
    modality = None
1304
    if part_type == "image_pil":
1305
        image_content = cast(Image.Image, content) if content is not None else None
1306
        mm_parser.parse_image_pil(image_content, uuid)
1307
        modality = "image"
1308
    elif part_type in ("image_url", "input_image"):
1309
        str_content = cast(str, content)
1310
        mm_parser.parse_image(str_content, uuid)
1311
1312
        modality = "image"
    elif part_type == "image_embeds":
1313
1314
1315
1316
        if content is not None:
            content = cast(Union[str, dict[str, str]], content)
        else:
            content = None
1317
        mm_parser.parse_image_embeds(content, uuid)
1318
1319
        modality = "image"
    elif part_type == "audio_url":
1320
        str_content = cast(str, content)
1321
        mm_parser.parse_audio(str_content, uuid)
1322
1323
        modality = "audio"
    elif part_type == "input_audio":
1324
        dict_content = cast(InputAudio, content)
1325
        mm_parser.parse_input_audio(dict_content, uuid)
1326
1327
        modality = "audio"
    elif part_type == "video_url":
1328
        str_content = cast(str, content)
1329
        mm_parser.parse_video(str_content, uuid)
1330
1331
1332
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")
1333

1334
1335
1336
    return (
        {"type": modality}
        if wrap_dicts
1337
        else (MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None)
1338
    )
1339
1340


1341
1342
1343
1344
1345
# No need to validate using Pydantic again: these are typing-only casts
# (``typing.cast`` returns its argument unchanged).
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


1346
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: _ChatTemplateContentFormat,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Convert one chat message into conversation message(s).

    The message content is normalised to a list of parts (``None`` becomes
    an empty list, a string becomes a single text part), parsed, and the
    result is then enriched with ``tool_calls`` (assistant messages),
    ``tool_call_id`` (tool messages) and ``name`` where the input has them.
    """
    role = message["role"]
    raw_content = message.get("content")

    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [ChatCompletionContentPartTextParam(type="text", text=raw_content)]
    else:
        parts = raw_content

    result = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
        wrap_dicts=(content_format == "openai"),
        interleave_strings=interleave_strings,
    )

    for result_msg in result:
        if role == "assistant":
            assistant_msg = _AssistantParser(message)

            # The 'tool_calls' is not None check ensures compatibility.
            # It's needed only if downstream code doesn't strictly
            # follow the OpenAI spec.
            if (
                "tool_calls" in assistant_msg
                and assistant_msg["tool_calls"] is not None
            ):
                result_msg["tool_calls"] = list(assistant_msg["tool_calls"])
        elif role == "tool":
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                result_msg["tool_call_id"] = tool_msg["tool_call_id"]

        # Only string names are carried over.
        if "name" in message and isinstance(message["name"], str):
            result_msg["name"] = message["name"]

    return result

1386

1387
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
1388
1389
1390
1391
1392
1393
    # per the Transformers docs & maintainers, tool call arguments in
    # assistant-role messages with tool_calls need to be dicts not JSON str -
    # this is how tool-use chat templates will expect them moving forwards
    # so, for messages that have tool_calls, parse the string (which we get
    # from openAI format) to dict
    for message in messages:
1394
1395
1396
1397
1398
        if (
            message["role"] == "assistant"
            and "tool_calls" in message
            and isinstance(message["tool_calls"], list)
        ):
1399
            for item in message["tool_calls"]:
1400
1401
1402
1403
1404
                # if arguments is None or empty string, set to {}
                if content := item["function"].get("arguments"):
                    item["function"]["arguments"] = json.loads(content)
                else:
                    item["function"]["arguments"] = {}
1405
1406


1407
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    Optional[MultiModalDataDict],
    Optional[MultiModalUUIDDict],
]:
    """Parse chat messages into a conversation plus collected media.

    Returns:
        A tuple of the parsed conversation, the tracker's multimodal data
        and the tracker's multimodal UUIDs.
    """
    mm_tracker = MultiModalItemTracker(model_config, tokenizer)

    # Interleaving only applies to the "string" content format and only
    # when the multimodal config enables it; the value is the same for
    # every message, so it is evaluated once.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
            )
        )

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
1437
1438


1439
def parse_chat_messages_futures(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    Awaitable[Optional[MultiModalDataDict]],
    Optional[MultiModalUUIDDict],
]:
    """Parse chat messages, collecting media through the async tracker.

    Returns:
        A tuple of the parsed conversation, an awaitable for the tracker's
        multimodal data and the tracker's multimodal UUIDs.
    """
    mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)

    # Interleaving only applies to the "string" content format and only
    # when the multimodal config enables it; the value is the same for
    # every message, so it is evaluated once.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
            )
        )

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
1469
1470


1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
    """Jinja extension that makes ``{% generation %}`` blocks parseable.

    Only the parsing side is provided here: the block body is wrapped in a
    call block that targets a ``_generation_support`` method, which this
    class does not define.
    """

    tags = {"generation"}

    def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
        """Parse a ``generation`` ... ``endgeneration`` block into a CallBlock."""
        # The first token in the stream is the tag itself; keep its line
        # number for the resulting node.
        opening_token = next(parser.stream)
        block_body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
        node = jinja2.nodes.CallBlock(
            self.call_method("_generation_support"), [], [], block_body
        )
        return node.set_lineno(opening_token.lineno)


1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
def _resolve_chat_template_kwargs(
    chat_template: str,
):
    """Return the undeclared variable names found in ``chat_template``.

    The template is only parsed, in a sandboxed environment with the
    ``generation`` tag and loop controls enabled.
    """
    sandbox = jinja2.sandbox.ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
    )
    return jinja2.meta.find_undeclared_variables(sandbox.parse(chat_template))


# Memoized variant of ``_resolve_chat_template_kwargs``, keyed by the
# template text, so each distinct template is parsed once while cached.
_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)


1500
1501
1502
1503
1504
1505
def resolve_chat_template_kwargs(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    chat_template: str,
    chat_template_kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Keep only the kwargs that the tokenizer or the template can use.

    A kwarg is kept when ``tokenizer.apply_chat_template`` accepts it as an
    explicit parameter or when the template references it as an undeclared
    variable. ``chat_template`` itself is always dropped.

    Returns:
        A new dict with the accepted entries of ``chat_template_kwargs``.
    """
    explicit_params = set()
    for name in chat_template_kwargs:
        if supports_kw(tokenizer.apply_chat_template, name, allow_var_kwargs=False):
            explicit_params.add(name)

    template_vars = _cached_resolve_chat_template_kwargs(chat_template)

    accepted = explicit_params | template_vars
    # We exclude chat_template from kwargs here, because
    # chat template has been already resolved at this stage
    accepted.discard("chat_template")

    filtered: dict[str, Any] = {}
    for name, value in chat_template_kwargs.items():
        if name in accepted:
            filtered[name] = value
    return filtered
1518
1519


1520
1521
def apply_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    conversation: list[ConversationMessage],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    *,
    model_config: ModelConfig,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> str:
    """Render ``conversation`` with the tokenizer's chat template.

    Args:
        tokenizer: Tokenizer whose ``apply_chat_template`` is used.
        conversation: Parsed conversation messages.
        chat_template: Optional template override.
        tools: Optional tool definitions passed to the template.
        model_config: Forwarded to template resolution.
        tokenize: Forwarded to the tokenizer (defaults to ``False`` here).
        **kwargs: Extra template kwargs; unsupported ones are filtered out.

    Raises:
        ValueError: If no chat template can be resolved, or if resolving
            the kwargs or applying the template fails for any reason.
    """
    resolved_template = resolve_hf_chat_template(
        tokenizer,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
    )

    if resolved_template is None:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        )

    try:
        usable_kwargs = resolve_chat_template_kwargs(
            tokenizer=tokenizer,
            chat_template=resolved_template,
            chat_template_kwargs=kwargs,
        )
        rendered = tokenizer.apply_chat_template(
            conversation=conversation,  # type: ignore[arg-type]
            tools=tools,  # type: ignore[arg-type]
            chat_template=resolved_template,
            tokenize=tokenize,
            **usable_kwargs,
        )
    # External library exceptions can sometimes occur despite the framework's
    # internal exception management capabilities.
    except Exception as exc:
        # Log and report any library-related exceptions for further
        # investigation.
        logger.exception(
            "An error occurred in `transformers` while applying chat template"
        )
        raise ValueError(str(exc)) from exc

    return rendered
1567

1568

1569
1570
def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: list[ChatCompletionMessageParam],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    **kwargs: Any,
) -> list[int]:
    """Apply the Mistral tokenizer's chat template to ``messages``.

    Args:
        tokenizer: Mistral tokenizer whose ``apply_chat_template`` is used.
        messages: Chat messages to render.
        chat_template: Passed to ``resolve_mistral_chat_template`` only.
        tools: Optional tool definitions.
        **kwargs: Forwarded to both the resolver and the tokenizer.

    Raises:
        ValueError: If applying the template fails for any reason.
    """
    from mistral_common.exceptions import MistralCommonException

    # The return value of resolve_mistral_chat_template is always None,
    # and we won't use it.
    resolve_mistral_chat_template(
        chat_template=chat_template,
        **kwargs,
    )

    try:
        token_ids = tokenizer.apply_chat_template(
            messages=messages,
            tools=tools,
            **kwargs,
        )
    # mistral-common uses assert statements to stop processing of input
    # if input does not comply with the expected format.
    # We convert those assertion errors to ValueErrors so they can be
    # properly caught in the preprocessing_input step
    except (AssertionError, MistralCommonException) as exc:
        raise ValueError(str(exc)) from exc

    # External library exceptions can sometimes occur despite the framework's
    # internal exception management capabilities.
    except Exception as exc:
        # Log and report any library-related exceptions for further
        # investigation.
        logger.exception(
            "An error occurred in `mistral_common` while applying chat template"
        )
        raise ValueError(str(exc)) from exc

    return token_ids
1607

1608

1609
1610
1611
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
    """Count the tool calls made by assistant messages in ``conversation``.

    Messages of other roles are ignored, as are assistant messages whose
    ``tool_calls`` entry is missing or ``None``.
    """
    return sum(
        len(list(calls))
        for msg in conversation
        if msg["role"] == "assistant"
        and (calls := msg.get("tool_calls")) is not None
    )


1618
1619
1620
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build a tool-call identifier.

    Args:
        id_type: ``"kimi_k2"`` selects the ``functions.<name>:<idx>`` form;
            any other value yields a random identifier.
        func_name: Function name used by the ``"kimi_k2"`` form.
        idx: Index used by the ``"kimi_k2"`` form.
    """
    if id_type != "kimi_k2":
        # by default return random
        return f"chatcmpl-tool-{random_uuid()}"
    return f"functions.{func_name}:{idx}"