chat_utils.py 55.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
from abc import ABC, abstractmethod
7
from collections import Counter, defaultdict, deque
8
from collections.abc import Awaitable, Iterable
9
from functools import cached_property, lru_cache, partial
10
from pathlib import Path
11
from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union, cast
12

13
14
15
import jinja2
import jinja2.ext
import jinja2.meta
16
import jinja2.nodes
17
18
import jinja2.parser
import jinja2.sandbox
19
import transformers.utils.chat_template_utils as hf_chat_utils
20

21
22
23
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import (
24
25
26
27
28
29
30
31
32
33
34
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionContentPartTextParam,
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
35
from openai.types.chat import (
36
37
38
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
39
from openai.types.responses import ResponseInputImageParam
40
from openai_harmony import Message as OpenAIHarmonyMessage
41
42
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
43

44
# yapf: enable
45
46
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin

47
# pydantic needs the TypedDict from typing_extensions
48
from typing_extensions import Required, TypeAlias, TypedDict
49

50
from vllm.config import ModelConfig
51
from vllm.logger import init_logger
52
from vllm.model_executor.models import SupportsMultiModal
53
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
54
from vllm.multimodal.utils import MediaConnector
55

56
# yapf: disable
57
58
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path

59
# yapf: enable
60
from vllm.transformers_utils.processor import cached_get_processor
61
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
62
from vllm.utils import random_uuid, supports_kw
63
64
65

logger = init_logger(__name__)

66
67
68
69
70
71
# Generic, model-independent placeholder token for each modality, keyed by
# modality name. Content parsers group model-specific placeholder strings
# under these generic tokens.
MODALITY_PLACEHOLDERS_MAP = {
    "image": "<##IMAGE##>",
    "audio": "<##AUDIO##>",
    "video": "<##VIDEO##>",
}

72

73
74
75
76
77
78
79
80
81
82
83
84
85
86
class AudioURL(TypedDict, total=False):
    """Reference to an audio clip, expressed as a URL string."""

    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Content part that carries an audio clip through an `audio_url` object."""

    # The audio reference; its `url` field holds the actual location/data URL.
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


87
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
88
    image_embeds: Optional[Union[str, dict[str, str]]]
89
90
91
92
93
94
95
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
96
97
98
99
100
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
101
102


103
104
105
106
107
108
109
110
111
112
113
114
115
116
class VideoURL(TypedDict, total=False):
    """Reference to a video, expressed as a URL string."""

    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Content part that carries a video through a `video_url` object."""

    # The video reference; its `url` field holds the actual location/data URL.
    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""


117
118
119
120
class PILImage(BaseModel):
    """
    A PIL.Image.Image object.
    """

    # `Image.Image` is an arbitrary (non-pydantic) type, so the model config
    # has to opt in to arbitrary types for the field below.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    image_pil: Image.Image


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """Simplified image content part that takes an in-memory PIL image.

    No `type` discriminator is declared; the part is described solely by its
    `image_pil` key (plus an optional `uuid`).

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    image_pil: Optional[PILImage]
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
141
142


143
144
145
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """Simplified image content part that takes a bare `image_url` string.

    This is supported by OpenAI API, although it is not documented.

    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """

    image_url: Optional[str]
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
159
160
161
162


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """Simplified audio content part that takes a bare `audio_url` string.

    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """

    audio_url: Optional[str]
171
172


173
174
175
176
177
178
179
180
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """Simplified video content part that takes a bare `video_url` string.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """

    video_url: Optional[str]
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
188
189


Julien Denize's avatar
Julien Denize committed
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    `thinking` and `type` are required; `closed` may be omitted.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


211
# Every content-part shape accepted inside a message's `content` list: the
# upstream OpenAI parts, the audio/video/embedding extensions and simplified
# forms declared in this module, plain strings, and "thinking" parts.
# The member order is kept as-is on purpose.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam,
    ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartVideoParam,
    ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPILImageParam,
    CustomChatCompletionContentSimpleImageParam,
    ChatCompletionContentPartImageEmbedsParam,
    CustomChatCompletionContentSimpleAudioParam,
    CustomChatCompletionContentSimpleVideoParam,
    str,
    CustomThinkCompletionContentParam,
]
225
226
227
228


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Chat message shape whose `role` is a free-form string.

    This is what enables custom roles in the Chat Completion API. Only
    `role` is required; every other key may be omitted.
    """

    role: Required[str]
    """The role of the message's author."""

    content: Union[str, list[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""

249

250
251
252
253
254
# Any message shape accepted as chat input: a standard OpenAI message param,
# the custom-role variant declared in this module, or an OpenAI Harmony
# message object.
ChatCompletionMessageParam = Union[
    OpenAIChatCompletionMessageParam,
    CustomChatCompletionMessageParam,
    OpenAIHarmonyMessage,
]
255
256


257
# TODO: Make fields ReadOnly once mypy supports it
258
259
260
261
class ConversationMessage(TypedDict, total=False):
    """Message dictionary in which only `role` is mandatory.

    `content` is either a single optional string or a list of
    string-to-string dictionaries.
    """

    role: Required[str]
    """The role of the message's author."""

    content: Union[Optional[str], list[dict[str, str]]]
    """The contents of the message"""

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    name: Optional[str]
    """The name of the function to call"""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""
273
274


275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
# Passed in by user. "auto" requests detection of the format from the chat
# template; the other two values name a concrete format.
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

# Used internally, i.e. after "auto" has been resolved to a concrete format.
_ChatTemplateContentFormat = Literal["string", "openai"]


def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Return whether `node` reads (load context) the variable `varname`."""
    return (
        isinstance(node, jinja2.nodes.Name)
        and node.ctx == "load"
        and node.name == varname
    )


def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Return whether `node` is `varname[key]` or `varname.key`.

    The subscript form only matches when the index is a constant equal to
    `key`; in both forms the accessed object must be a plain read of
    `varname`.
    """
    # Subscript form: varname["key"]
    if isinstance(node, jinja2.nodes.Getitem):
        if not _is_var_access(node.node, varname):
            return False

        subscript = node.arg
        return isinstance(subscript, jinja2.nodes.Const) and subscript.value == key

    # Attribute form: varname.key
    if isinstance(node, jinja2.nodes.Getattr):
        return _is_var_access(node.node, varname) and node.attr == key

    return False


def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: Optional[str] = None,
) -> bool:
    """Return whether `node` accesses `varname` (or `varname[key]`), possibly
    wrapped in filters, tests or slices.

    Filter, test and slice-subscript nodes are unwrapped recursively. The
    innermost node is then checked as an attribute/item access when `key` is
    truthy, or as a plain variable access otherwise.
    """
    if isinstance(node, jinja2.nodes.Filter):
        wrapped = node.node
        if wrapped is None:
            return False
        return _is_var_or_elems_access(wrapped, varname, key)

    if isinstance(node, jinja2.nodes.Test):
        return _is_var_or_elems_access(node.node, varname, key)

    is_sliced = isinstance(node, jinja2.nodes.Getitem) and isinstance(
        node.arg, jinja2.nodes.Slice
    )
    if is_sliced:
        return _is_var_or_elems_access(node.node, varname, key)

    if key:
        return _is_attr_access(node, varname, key)
    return _is_var_access(node, varname)


def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
    """Yield `(node, name)` for every variable that aliases `varname`.

    The first pair is `(root, varname)` itself, since the variable is
    implicitly defined at the root. After that, every `{% set x = <expr> %}`
    whose right-hand side accesses an already known alias (optionally through
    filters, tests or slices) yields `(assign_node, "x")`, and `x` becomes a
    known alias in turn.

    Each alias name is expanded at most once, so cyclic assignment chains
    such as `{% set a = messages %}{% set messages = a %}` terminate instead
    of re-queueing the same names forever.
    """
    # Global variable that is implicitly defined at the root
    yield root, varname

    # Iterative BFS over alias names. `seen_varnames` holds every name that
    # has been queued already, which covers both direct self-assignment and
    # longer assignment cycles.
    seen_varnames = {varname}
    pending_varnames = deque([varname])
    while pending_varnames:
        related_varname = pending_varnames.popleft()

        for assign_ast in root.find_all(jinja2.nodes.Assign):
            lhs = assign_ast.target
            rhs = assign_ast.node

            if _is_var_or_elems_access(rhs, related_varname):
                assert isinstance(lhs, jinja2.nodes.Name)
                yield assign_ast, lhs.name

                if lhs.name not in seen_varnames:
                    seen_varnames.add(lhs.name)
                    pending_varnames.append(lhs.name)


# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
    """Yield `(for_node, target_name)` for each loop over the messages.

    A loop qualifies when its iterable accesses `messages` or any variable
    found to alias it, e.g. `{%- for message in messages -%}`. Each loop is
    yielded at most once.
    """
    aliases = [name for _, name in _iter_nodes_assign_var_or_elems(root, "messages")]

    for for_node in root.find_all(jinja2.nodes.For):
        iterable = for_node.iter

        if any(_is_var_or_elems_access(iterable, alias) for alias in aliases):
            target = for_node.target
            assert isinstance(target, jinja2.nodes.Name)
            yield for_node, target.name


def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    """Yield `(for_node, target_name)` for each loop over a message's content.

    A loop qualifies when its iterable accesses the `content` key/attribute
    of a message loop variable, e.g.
    `{%- for content in message['content'] -%}`. Each loop is yielded at
    most once.
    """
    message_names = [name for _, name in _iter_nodes_assign_messages_item(root)]

    for for_node in root.find_all(jinja2.nodes.For):
        iterable = for_node.iter

        if any(
            _is_var_or_elems_access(iterable, message_name, "content")
            for message_name in message_names
        ):
            target = for_node.target
            assert isinstance(target, jinja2.nodes.Name)
            yield for_node, target.name


def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
    """Parse `chat_template` into a Jinja AST, or return `None` on failure.

    The template is first compiled through the HF chat-template helper and
    then parsed with that compiled template's environment. Any exception is
    logged and swallowed.
    """
    try:
        compiled = hf_chat_utils._compile_jinja_template(chat_template)
        template_ast = compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None

    return template_ast


394
@lru_cache(maxsize=32)
def _detect_content_format(
    chat_template: str,
    *,
    default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
    """Guess the content format a chat template expects.

    Returns "openai" when the template contains a loop over a message's
    `content`, "string" when it contains none, and `default` when the
    template cannot be compiled or walking its AST raises. Results are
    memoized per `(chat_template, default)`.
    """
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    # Sentinel distinguishing "no content loop found" from a found loop.
    no_match = object()
    try:
        first_match = next(_iter_nodes_assign_content_item(jinja_ast), no_match)
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default

    if first_match is no_match:
        return "string"
    return "openai"


415
416
417
418
419
420
def resolve_mistral_chat_template(
    chat_template: Optional[str],
    **kwargs: Any,
) -> Optional[str]:
    """Warn about options that do not apply to the mistral tokenizer.

    A warning is emitted (once per message) when a chat template is given,
    and for each of `add_generation_prompt` / `continue_final_message`
    present in `kwargs`. The function always returns `None`.
    """
    if chat_template is not None:
        logger.warning_once(
            "'chat_template' cannot be overridden for mistral tokenizer."
        )

    ignored_options = (
        (
            "add_generation_prompt",
            "'add_generation_prompt' is not supported for mistral tokenizer, "
            "so it will be ignored.",
        ),
        (
            "continue_final_message",
            "'continue_final_message' is not supported for mistral tokenizer, "
            "so it will be ignored.",
        ),
    )
    for option_name, message in ignored_options:
        if option_name in kwargs:
            logger.warning_once(message)

    return None

435

436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
# Keyed by (tokenizer name_or_path, trust_remote_code); the value is the
# processor's chat template, or None when none could be obtained.
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], Optional[str]]()
"""
Used in `_try_get_processor_chat_template` to avoid calling
`cached_get_processor` again if the processor fails to be loaded.

This is needed because `lru_cache` does not cache when an exception happens.
"""


def _try_get_processor_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    model_config: ModelConfig,
) -> Optional[str]:
    """Return the chat template of the model's processor, if it has one.

    The outcome, including `None` for "no template" or "loading failed", is
    memoized in `_PROCESSOR_CHAT_TEMPLATES` per
    `(tokenizer.name_or_path, model_config.trust_remote_code)`, so the
    processor is loaded at most once per key. Loading errors are logged at
    debug level and never propagate.
    """
    cache_key = (tokenizer.name_or_path, model_config.trust_remote_code)
    try:
        return _PROCESSOR_CHAT_TEMPLATES[cache_key]
    except KeyError:
        pass

    chat_template: Optional[str] = None
    try:
        processor = cached_get_processor(
            tokenizer.name_or_path,
            processor_cls=(
                PreTrainedTokenizer,
                PreTrainedTokenizerFast,
                ProcessorMixin,
            ),
            trust_remote_code=model_config.trust_remote_code,
        )
        # Only a real processor (not a bare tokenizer) may provide a template.
        if isinstance(processor, ProcessorMixin):
            chat_template = getattr(processor, "chat_template", None)
    except Exception:
        logger.debug(
            "Failed to load AutoProcessor chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template
    return chat_template


481
def resolve_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    *,
    model_config: ModelConfig,
) -> Optional[str]:
    """Pick the chat template to use for a HF tokenizer.

    Sources are tried in priority order:

    1. The explicitly given `chat_template`.
    2. The processor's chat template (only when `tools` is `None`).
    3. The tokenizer's own chat template.
    4. A predefined fallback template for the model type / tokenizer.

    Returns `None` when no source yields a template.
    """
    # 1st priority: The given chat template
    if chat_template is not None:
        return chat_template

    # 2nd priority: AutoProcessor chat template, unless tool calling is enabled
    if tools is None:
        processor_template = _try_get_processor_chat_template(tokenizer, model_config)
        if processor_template is not None:
            return processor_template

    # 3rd priority: AutoTokenizer chat template
    try:
        # No explicit template is available at this point, hence `None`.
        return tokenizer.get_chat_template(None, tools=tools)
    except Exception:
        logger.debug(
            "Failed to load AutoTokenizer chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    # 4th priority: Predefined fallbacks
    fallback_path = get_chat_template_fallback_path(
        model_type=model_config.hf_config.model_type,
        tokenizer_name_or_path=model_config.tokenizer,
    )
    if fallback_path is None:
        logger.debug(
            "There is no chat template fallback for %s", tokenizer.name_or_path
        )
        return None

    logger.info(
        "Loading chat template fallback for %s as there isn't one "
        "defined on HF Hub.",
        tokenizer.name_or_path,
    )
    return load_chat_template(fallback_path)
526
527


528
529
def _resolve_chat_template_content_format(
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """Detect the content format from whichever chat template applies.

    For HF tokenizers the template is resolved via
    `resolve_hf_chat_template`. If that does not yield a string (or the
    tokenizer is not an HF one), the given `chat_template` is loaded as a
    literal instead. Without any template text the format is "string".
    """
    hf_chat_template = None
    if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        hf_chat_template = resolve_hf_chat_template(
            tokenizer,
            chat_template=chat_template,
            tools=tools,
            model_config=model_config,
        )

    if isinstance(hf_chat_template, str):
        jinja_text = hf_chat_template
    else:
        jinja_text = load_chat_template(chat_template, is_literal=True)

    if jinja_text is None:
        return "string"

    return _detect_content_format(jinja_text, default="string")
558
559
560


@lru_cache
def _log_chat_template_content_format(
    chat_template: Optional[str],
    given_format: ChatTemplateContentFormatOption,
    detected_format: ChatTemplateContentFormatOption,
):
    """Log the detected content format, warning on a mismatch.

    The `lru_cache` wrapper means each distinct argument combination is
    logged only while it is absent from the cache, which keeps repeated
    requests from flooding the log.
    """
    logger.info(
        "Detected the chat template content format to be '%s'. "
        "You can set `--chat-template-content-format` to override this.",
        detected_format,
    )

    # Nothing to warn about when the user left it on "auto" or agrees with us.
    if given_format == "auto" or given_format == detected_format:
        return

    logger.warning(
        "You specified `--chat-template-content-format %s` "
        "which is different from the detected format '%s'. "
        "If our automatic detection is incorrect, please consider "
        "opening a GitHub issue so that we can improve it: "
        "https://github.com/vllm-project/vllm/issues/new/choose",
        given_format,
        detected_format,
    )

583
584
585
586
587
588

def resolve_chat_template_content_format(
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    given_format: ChatTemplateContentFormatOption,
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """Resolve the user's content-format option to a concrete format.

    Any value other than "auto" is returned unchanged. For "auto" the format
    is detected from the applicable chat template, logged, and returned.
    """
    # An explicit choice always wins; no detection or logging takes place.
    if given_format != "auto":
        return given_format

    resolved_format = _resolve_chat_template_content_format(
        chat_template,
        tools,
        tokenizer,
        model_config=model_config,
    )

    _log_chat_template_content_format(
        chat_template,
        given_format=given_format,
        detected_format=resolved_format,
    )

    return resolved_format
609

610

611
# Modality names accepted by the multi-modal item trackers and parsers.
ModalityStr = Literal["image", "audio", "video", "image_embeds"]
612
613
614
615
# Type of the items stored by a multi-modal item tracker.
_T = TypeVar("_T")


class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Collects the multi-modal items (and their optional UUIDs) of a request,
    grouped by modality.

    Every added item is counted against the multi-modal processor's
    per-prompt limit through `validate_num_items`.
    """

    def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
        super().__init__()

        self._model_config = model_config
        self._tokenizer = tokenizer

        # Parallel per-modality lists: entry i of the UUID list belongs to
        # entry i of the item list.
        self._items_by_modality = defaultdict[str, list[Optional[_T]]](list)
        self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list)

    @property
    def model_config(self) -> ModelConfig:
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        # Imported here rather than at module level.
        from vllm.model_executor.model_loader import get_model_cls

        return cast(type[SupportsMultiModal], get_model_cls(self.model_config))

    @property
    def allowed_local_media_path(self):
        return self._model_config.allowed_local_media_path

    @property
    def allowed_media_domains(self):
        return self._model_config.allowed_media_domains

    @property
    def mm_registry(self):
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        return self.mm_registry.create_processor(self.model_config)

    def add(
        self,
        modality: ModalityStr,
        item: Optional[_T],
        uuid: Optional[str] = None,
    ) -> Optional[str]:
        """
        Register one multi-modal item for the current prompt and return the
        placeholder string to use for it, if any.

        `uuid`, when given, is stored alongside the item as the unique
        identifier of the media.
        """
        # Limits are validated against the underlying input modality, so
        # "image_embeds" is checked as "image".
        input_modality = modality.replace("_embeds", "")
        existing_items = self._items_by_modality[modality]
        item_number = len(existing_items) + 1

        self.mm_processor.validate_num_items(input_modality, item_number)

        existing_items.append(item)
        self._uuids_by_modality[modality].append(uuid)

        return self.model_cls.get_placeholder_str(modality, item_number)

    def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]:
        """
        Return the collected UUIDs per modality, or `None` if nothing was
        added. UUIDs of image embeddings are reported under "image".

        Raises `ValueError` when raw images and image embeddings are mixed,
        or when more than one image-embeds entry was added.
        """
        if not self._items_by_modality:
            return None

        uuids_by_modality = dict(self._uuids_by_modality)
        has_image_embeds = "image_embeds" in uuids_by_modality
        if "image" in uuids_by_modality and has_image_embeds:
            raise ValueError("Mixing raw image and embedding inputs is not allowed")

        mm_uuids = {}
        if has_image_embeds:
            if len(uuids_by_modality["image_embeds"]) > 1:
                raise ValueError("Only one message can have {'type': 'image_embeds'}")
            mm_uuids["image"] = uuids_by_modality["image_embeds"]

        # UUIDs of images, audios and videos, in that order
        for modality in ("image", "audio", "video"):
            if modality in uuids_by_modality:
                mm_uuids[modality] = uuids_by_modality[modality]
        return mm_uuids

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        raise NotImplementedError


707
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
    """Tracker whose items are already-loaded media objects."""

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """
        Return the collected items per modality, or `None` if nothing was
        added. A single image-embeds entry is reported under "image".

        Raises `ValueError` when raw images and image embeddings are mixed,
        or when more than one image-embeds entry was added.
        """
        if not self._items_by_modality:
            return None

        items_by_modality = dict(self._items_by_modality)
        has_image_embeds = "image_embeds" in items_by_modality
        if "image" in items_by_modality and has_image_embeds:
            raise ValueError("Mixing raw image and embedding inputs is not allowed")

        mm_inputs = {}
        if has_image_embeds:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError("Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]

        # Lists of images, audios and videos, in that order
        for modality in ("image", "audio", "video"):
            if modality in items_by_modality:
                mm_inputs[modality] = items_by_modality[modality]
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)


733
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
    """Tracker whose items are awaitables that resolve to media objects."""

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """
        Await all collected items and return them per modality, or `None`
        if nothing was added. A single image-embeds entry is reported under
        "image".

        Raises `ValueError` when raw images and image embeddings are mixed,
        or when more than one image-embeds entry was added.
        """
        if not self._items_by_modality:
            return None

        items_by_modality = {}
        for modality, pending_items in self._items_by_modality.items():
            # A missing item is replaced by `asyncio.sleep(0)` so that it
            # still occupies its position in the gathered results.
            awaitables = [
                pending if pending is not None else asyncio.sleep(0)
                for pending in pending_items
            ]
            items_by_modality[modality] = await asyncio.gather(*awaitables)

        has_image_embeds = "image_embeds" in items_by_modality
        if "image" in items_by_modality and has_image_embeds:
            raise ValueError("Mixing raw image and embedding inputs is not allowed")

        mm_inputs = {}
        if has_image_embeds:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError("Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]

        # Lists of images, audios and videos, in that order
        for modality in ("image", "audio", "video"):
            if modality in items_by_modality:
                mm_inputs[modality] = items_by_modality[modality]
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    """Interface for parsers that turn media content parts into placeholders.

    Subclasses implement one `parse_*` method per supported content-part
    kind. This base class only keeps track of the produced placeholders.
    """

    def __init__(self) -> None:
        super().__init__()

        # Model placeholder strings grouped under the generic placeholder of
        # their modality, for example:
        # {
        #   "<##IMAGE##>": ["<image>", "<image>", "<image>"],
        #   "<##AUDIO##>": ["<audio>", "<audio>"]
        # }
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(self, modality: ModalityStr, placeholder: Optional[str]):
        """Record `placeholder` for `modality`; falsy placeholders are skipped."""
        generic_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
        if not placeholder:
            return
        self._placeholder_storage[generic_placeholder].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return the recorded placeholders as a plain `dict`."""
        return dict(self._placeholder_storage)

    @abstractmethod
    def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: Union[str, dict[str, str], None],
        uuid: Optional[str] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
        raise NotImplementedError

820
821
822
823
824
825

class MultiModalContentParser(BaseMultiModalContentParser):
    """Parser that fetches media through the connector's synchronous
    `fetch_*` calls and registers the results with a `MultiModalItemTracker`.
    """

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

        # `media_io_kwargs` is read defensively and falls back to `None`.
        multimodal_config = self._tracker.model_config.multimodal_config
        self._connector = MediaConnector(
            media_io_kwargs=getattr(multimodal_config, "media_io_kwargs", None),
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
        """Fetch the image (if a URL is given) and register it."""
        image = None
        if image_url:
            image = self._connector.fetch_image(image_url)

        self._add_placeholder("image", self._tracker.add("image", image, uuid))

    def parse_image_embeds(
        self,
        image_embeds: Union[str, dict[str, str], None],
        uuid: Optional[str] = None,
    ) -> None:
        """Decode the embeddings (one string, or a dict of them) and register
        them under the "image_embeds" modality.
        """
        # NOTE(review): a value that is neither None, str nor dict leaves
        # `placeholder` unbound, exactly as before this rewrite.
        if image_embeds is None:
            placeholder = self._tracker.add("image_embeds", None, uuid)
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            placeholder = self._tracker.add("image_embeds", embedding, uuid)
        elif isinstance(image_embeds, dict):
            embeds = {
                name: self._connector.fetch_image_embedding(encoded)
                for name, encoded in image_embeds.items()
            }
            placeholder = self._tracker.add("image_embeds", embeds, uuid)

        # Embeddings share the generic image placeholder.
        self._add_placeholder("image", placeholder)

    def parse_image_pil(
        self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
    ) -> None:
        """Register an already-loaded PIL image."""
        self._add_placeholder("image", self._tracker.add("image", image_pil, uuid))

    def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
        """Fetch the audio (if a URL is given) and register it."""
        audio = None
        if audio_url:
            audio = self._connector.fetch_audio(audio_url)

        self._add_placeholder("audio", self._tracker.add("audio", audio, uuid))

    def parse_input_audio(
        self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
    ) -> None:
        """Convert inline audio into a base64 data URL and parse that."""
        audio_url = None
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            # If a UUID is provided, audio data may be empty.
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
        """Fetch the video (if a URL is given) and register it."""
        video = None
        if video_url:
            video = self._connector.fetch_video(video_url=video_url)

        self._add_placeholder("video", self._tracker.add("video", video, uuid))
894

895
896
897
898
899
900

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Content parser that registers multimodal items for later resolution.

    URL-based media are handed to the tracker as whatever the connector's
    ``fetch_*_async`` methods return (presumably awaitables resolved later by
    the tracker -- confirm against ``AsyncMultiModalItemTracker``). Values
    that are already available are wrapped in completed ``asyncio.Future``
    objects before being registered.
    """

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

        # ``multimodal_config`` may lack ``media_io_kwargs`` (or be None), in
        # which case None is forwarded to the connector.
        multimodal_config = self._tracker.model_config.multimodal_config
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
        self._connector = MediaConnector(
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
        """Register an image given by URL; None is registered for an empty URL."""
        image_coro = self._connector.fetch_image_async(image_url) if image_url else None

        placeholder = self._tracker.add("image", image_coro, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: Union[str, dict[str, str], None],
        uuid: Optional[str] = None,
    ) -> None:
        """Register precomputed image embeddings.

        Args:
            image_embeds: A single encoded embedding, a dict of named encoded
                embeddings, or None.
            uuid: Optional identifier forwarded to the tracker with the item.

        Raises:
            TypeError: If ``image_embeds`` is none of the supported types.
                Previously such input left the registered future unresolved,
                so anything awaiting it would never complete.
        """
        embeds: Any
        if image_embeds is None:
            embeds = None
        elif isinstance(image_embeds, str):
            embeds = self._connector.fetch_image_embedding(image_embeds)
        elif isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
        else:
            raise TypeError(
                "image_embeds must be a str, a dict or None, "
                f"got {type(image_embeds)}"
            )

        # The value is already available, so register a completed future.
        future: asyncio.Future[Union[str, dict[str, str], None]] = asyncio.Future()
        future.set_result(embeds)

        placeholder = self._tracker.add("image_embeds", future, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_pil(
        self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
    ) -> None:
        """Register an in-memory PIL image wrapped in a completed future."""
        future: asyncio.Future[Optional[Image.Image]] = asyncio.Future()
        future.set_result(image_pil if image_pil else None)

        placeholder = self._tracker.add("image", future, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
        """Register audio given by URL; None is registered for an empty URL."""
        audio_coro = self._connector.fetch_audio_async(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", audio_coro, uuid)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
    ) -> None:
        """Register inline base64 audio by turning it into a ``data:`` URL.

        ``input_audio`` carries the payload under ``"data"`` and the format
        under ``"format"``; when it is missing or has no data, None is passed
        on to ``parse_audio``.
        """
        audio_url = None
        if input_audio:
            audio_data = input_audio.get("data", "")
            # If a UUID is provided, audio data may be empty.
            if audio_data:
                audio_format = input_audio.get("format", "")
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
        """Register a video given by URL; None is registered for an empty URL."""
        video = (
            self._connector.fetch_video_async(video_url=video_url)
            if video_url
            else None
        )

        placeholder = self._tracker.add("video", video, uuid)
        self._add_placeholder("video", placeholder)
982

983

984
985
986
987
988
989
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
    """Raise if the provided chat template appears invalid.

    Args:
        chat_template: None, a ``Path`` to a template file, or a string that
            is either a literal template or a path to one.

    Raises:
        FileNotFoundError: If a ``Path`` is given that does not exist.
        ValueError: If a string contains none of the characters ``{``, ``}``
            or newline (so it looks like a path) but no such file exists.
        TypeError: If the value is neither None, a ``Path`` nor a ``str``.
    """
    if chat_template is None:
        return

    if isinstance(chat_template, Path):
        # An existing Path is valid. It must return here rather than fall
        # through to the TypeError below.
        if not chat_template.exists():
            raise FileNotFoundError("the supplied chat template path doesn't exist")
        return

    if isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
        looks_like_path = not any(c in chat_template for c in JINJA_CHARS)
        if looks_like_path and not Path(chat_template).exists():
            raise ValueError(
                f"The supplied chat template string ({chat_template}) "
                f"appears path-like, but doesn't exist!"
            )
        return

    raise TypeError(f"{type(chat_template)} is not a valid chat template type")
1005
1006


1007
def _load_chat_template(
    chat_template: Optional[Union[Path, str]],
    *,
    is_literal: bool = False,
) -> Optional[str]:
    """Resolve a chat template argument to the template text.

    Args:
        chat_template: None, a path to a template file, or the template text.
        is_literal: When True the value is returned as-is without touching
            the file system.

    Returns:
        None if ``chat_template`` is None, otherwise the template text.

    Raises:
        TypeError: If ``is_literal`` is True but a ``Path`` was given.
        OSError: If a ``Path`` was given and it cannot be opened.
        ValueError: If a string that contains none of ``{``, ``}`` or newline
            (so it looks like a path) cannot be opened.
    """
    if chat_template is None:
        return None

    if is_literal:
        if isinstance(chat_template, Path):
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )

        return chat_template

    try:
        # Read with an explicit encoding so the result does not depend on
        # the platform's locale.
        with open(chat_template, encoding="utf-8") as f:
            return f.read()
    except OSError as e:
        if isinstance(chat_template, Path):
            raise

        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
            msg = (
                f"The supplied chat template ({chat_template}) "
                f"looks like a file path, but it failed to be "
                f"opened. Reason: {e}"
            )
            raise ValueError(msg) from e

        # The string could not be opened as a file but contains template
        # syntax, so treat it as the template text itself.
        return chat_template


# Memoized variant of ``_load_chat_template``. ``lru_cache`` is applied with
# its default settings and keys on the (chat_template, is_literal) arguments.
_cached_load_chat_template = lru_cache(_load_chat_template)


def load_chat_template(
    chat_template: Optional[Union[Path, str]],
    *,
    is_literal: bool = False,
) -> Optional[str]:
    """Return the chat template text, reusing cached results.

    Delegates to the ``lru_cache``-wrapped ``_load_chat_template``. Because
    results are cached per argument pair, a template file that changes on
    disk after it was first loaded is not re-read while its entry stays in
    the cache.

    Args:
        chat_template: None, a path to a template file, or the template text.
        is_literal: Forwarded unchanged to ``_load_chat_template``.

    Returns:
        Whatever ``_load_chat_template`` returns for these arguments.
    """
    return _cached_load_chat_template(chat_template, is_literal=is_literal)
1053
1054


1055
1056
1057
def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
    """Join ``texts`` with newlines, substituting stored placeholders in order.

    Every element of ``texts`` that equals a key of ``placeholder_storage``
    is replaced by the next unused entry of that key's list. Neither argument
    is modified.

    Args:
        placeholder_storage: Maps a marker string to the placeholders that
            should replace successive occurrences of that marker.
        texts: Text fragments, some of which may be marker strings.

    Returns:
        The fragments joined by ``"\\n"`` after substitution. A marker whose
        list is already exhausted is left unchanged instead of raising.
    """
    # One independent iterator per marker consumes entries front-to-back
    # without popping from (and thereby mutating) the caller's lists.
    pending = {
        marker: iter(placeholders)
        for marker, placeholders in placeholder_storage.items()
    }

    merged: list[str] = []
    for text in texts:
        source = pending.get(text)
        merged.append(text if source is None else next(source, text))

    return "\n".join(merged)


1065
# TODO: Let user specify how to insert multimodal tokens into prompt
1066
# (similar to chat template)
1067
1068
1069
1070
1071
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
) -> str:
    """Combine multimodal prompts for a multimodal language model.

    Args:
        placeholder_storage: Maps a key to the list of placeholders
            registered under it.
        texts: Text fragments of the message.
        interleave_strings: When True, fragments equal to a storage key are
            replaced by stored placeholders before joining.

    Returns:
        The newline-joined prompt, with any placeholder that does not yet
        occur often enough in the text prepended once per missing occurrence.

    Raises:
        ValueError: If the text contains more occurrences of a placeholder
            than there are multimodal items for it.
    """
    # Flatten the storage so that it looks like
    # {
    #   "<|image|>": 2,
    #   "<|audio|>": 1
    # }
    # This is computed before interleaving on purpose: it must reflect the
    # full storage regardless of what the interleaving step does with it.
    placeholder_counts = Counter(
        placeholder
        for placeholders in placeholder_storage.values()
        for placeholder in placeholders
    )

    if interleave_strings:
        text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
    else:
        text_prompt = "\n".join(texts)

    # The interleaved text is checked as well, in case the user placed image
    # placeholders manually but forgot to disable the 'interleave_strings'
    # flag.

    # Look through the text prompt to check for missing placeholders
    missing_placeholders: list[str] = []
    for placeholder, expected in placeholder_counts.items():
        # For any existing placeholder in the text prompt, we leave it as is
        missing = expected - text_prompt.count(placeholder)

        if missing < 0:
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
                "when manually placing image placeholders.",
                interleave_strings,
            )
            logger.debug("Input prompt: %s", text_prompt)
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
                "actual multimodal data items."
            )

        missing_placeholders.extend([placeholder] * missing)

    # NOTE: Default behaviour: we always add missing placeholders
    # at the front of the prompt, if interleave_strings=False
    return "\n".join(missing_placeholders + [text_prompt])
1116
1117


1118
1119
# No need to validate using Pydantic again: these only cast the part to the
# corresponding typed dict for the type checker.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)
# Need to validate url objects
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]

# Define a mapping from part types to their corresponding parsing functions.
# Each function takes the raw content part and returns the payload stored
# under the part's own key (the nested "url" for the ``*_url`` types), or
# None when that key is absent.
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    "text": lambda part: _TextParser(part).get("text", None),
    "thinking": lambda part: _ThinkParser(part).get("thinking", None),
    "input_text": lambda part: _TextParser(part).get("text", None),
    "input_image": lambda part: _ResponsesInputImageParser(part).get("image_url", None),
    "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
    "image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
    "image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
    "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
    "input_audio": lambda part: _InputAudioParser(part).get("input_audio", None),
    "refusal": lambda part: _RefusalParser(part).get("refusal", None),
    "video_url": lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
}


def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
    """
    Parses a given multi-modal content part based on its type.

    Parts with a known 'type' and no 'uuid' are dispatched through
    ``MM_PARSER_MAP``. Parts without a 'type', or with a 'uuid', are
    recognised by the media key they contain; the value under that key may
    then be None.

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL).

    Raises:
        ValueError: If the 'type' field is missing (or a 'uuid' is given) and
            no known media key is found, or if 'type' is not a string.
    """
    assert isinstance(
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
    fields = cast(dict[str, Any], part)
    part_type = fields.get("type", None)
    uuid = fields.get("uuid", None)

    if uuid is None and isinstance(part_type, str) and part_type in MM_PARSER_MAP:
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
        # We only support 'auto', which is the default.
        # In the OpenAI schema 'detail' is nested inside the 'image_url'
        # object, so that is checked in addition to the top-level key.
        if part_type == "image_url":
            image_field = fields.get("image_url")
            nested_detail = (
                image_field.get("detail") if isinstance(image_field, dict) else None
            )
            if fields.get("detail", "auto") != "auto" or nested_detail not in (
                None,
                "auto",
            ):
                logger.warning(
                    "'image_url.detail' is currently not supported and will be ignored."
                )

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
    # 'type' is required field by pydantic
    if part_type is None or uuid is not None:
        if "image_url" in fields:
            return "image_url", _unwrap_url_field(fields.get("image_url", None))
        if "image_pil" in fields:
            # "image_pil" could be None if UUID is provided.
            return "image_pil", fields.get("image_pil", None)
        if "image_embeds" in fields:
            # "image_embeds" could be None if UUID is provided.
            return "image_embeds", fields.get("image_embeds", None)
        if "audio_url" in fields:
            return "audio_url", _unwrap_url_field(fields.get("audio_url", None))
        if "input_audio" in fields:
            # Return the inner payload ({"data": ..., "format": ...}), like
            # the typed path does; returning the whole part made the audio
            # data invisible to the consumer. It may be None if a UUID is
            # provided.
            return "input_audio", fields.get("input_audio", None)
        if "video_url" in fields:
            return "video_url", _unwrap_url_field(fields.get("video_url", None))
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    return part_type, "unknown part_type content"


def _unwrap_url_field(value: Any) -> Any:
    """Return ``value["url"]`` for a dict value, otherwise ``value`` itself."""
    if isinstance(value, dict):
        # Can potentially happen if user provides a uuid
        # with url as a dict of {"url": url}
        return value.get("url", None)
    return value


1239
# Part types that are skipped (with a warning) by
# ``_parse_chat_message_content_part`` when their parsed content is None.
PART_TYPES_TO_SKIP_NONE_CONTENT = (
    "text",
    "refusal",
)
1243

1244

1245
1246
1247
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse all content parts of one message into a conversation message.

    Args:
        role: Role to assign to the resulting message.
        parts: The content parts of the message.
        mm_tracker: Tracker from which a fresh content parser is created.
        wrap_dicts: When True, the content is kept as a list of dicts;
            otherwise it is flattened into a single string.
        interleave_strings: Forwarded to the per-part parser and to the
            prompt assembly; only relevant when ``wrap_dicts`` is False.

    Returns:
        A single-element list holding the resulting message.
    """
    content = list[_ContentPart]()

    mm_parser = mm_tracker.create_parser()

    for part in parts:
        parse_res = _parse_chat_message_content_part(
            part,
            mm_parser,
            wrap_dicts=wrap_dicts,
            interleave_strings=interleave_strings,
        )
        # Falsy results (None, empty string) contribute nothing.
        if parse_res:
            content.append(parse_res)

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role, content=content)]  # type: ignore

    texts = cast(list[str], content)
    mm_placeholder_storage = mm_parser.mm_placeholder_storage()
    if mm_placeholder_storage:
        text_prompt = _get_full_multimodal_text_prompt(
            mm_placeholder_storage, texts, interleave_strings
        )
    else:
        text_prompt = "\n".join(texts)

    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> Optional[_ContentPart]:
    """Parses a single part of a conversation. If wrap_dicts is True,
    structured dictionary pieces for texts and images will be
    wrapped in dictionaries, i.e., {"type": "text", "text": ...} and
    {"type": "image"}, respectively. Otherwise multimodal data will be
    handled by mm_parser, and texts will be returned as strings to be joined
    with multimodal placeholders.

    Returns None for a skipped part, and for a media part when neither
    ``wrap_dicts`` nor ``interleave_strings`` is set.

    Raises:
        NotImplementedError: If the part type is not one handled here.
    """
    if isinstance(part, str):  # Handle plain text parts
        return part

    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)

    # If part_type is listed in PART_TYPES_TO_SKIP_NONE_CONTENT but the
    # content is None, log a warning and skip the part.
    if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None:
        logger.warning(
            "Skipping multimodal part '%s' (type: '%s') "
            "with empty / unparsable content.",
            part,
            part_type,
        )
        return None

    if part_type in ("text", "input_text", "refusal", "thinking"):
        str_content = cast(str, content)
        if wrap_dicts:
            return {"type": "text", "text": str_content}
        return str_content

    # For media items, if a user has provided a uuid, use it (as a string).
    # Otherwise None is passed along.
    uuid = part.get("uuid", None)
    if uuid is not None:
        uuid = str(uuid)

    modality = None
    if part_type == "image_pil":
        image_content = cast(Image.Image, content) if content is not None else None
        mm_parser.parse_image_pil(image_content, uuid)
        modality = "image"
    elif part_type in ("image_url", "input_image"):
        str_content = cast(str, content)
        mm_parser.parse_image(str_content, uuid)
        modality = "image"
    elif part_type == "image_embeds":
        embeds_content = cast(Optional[Union[str, dict[str, str]]], content)
        mm_parser.parse_image_embeds(embeds_content, uuid)
        modality = "image"
    elif part_type == "audio_url":
        str_content = cast(str, content)
        mm_parser.parse_audio(str_content, uuid)
        modality = "audio"
    elif part_type == "input_audio":
        dict_content = cast(InputAudio, content)
        mm_parser.parse_input_audio(dict_content, uuid)
        modality = "audio"
    elif part_type == "video_url":
        str_content = cast(str, content)
        mm_parser.parse_video(str_content, uuid)
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")

    if wrap_dicts:
        return {"type": modality}
    # In string mode the media item only leaves a marker in the text when
    # interleaving is requested.
    return MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
1363
1364


1365
1366
1367
1368
1369
# No need to validate using Pydantic again: these are plain ``typing.cast``
# wrappers, so they return the message object unchanged and only inform the
# type checker of its typed-dict type.
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


1370
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: _ChatTemplateContentFormat,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Convert one request message into conversation message(s).

    Args:
        message: The incoming message; its "content" may be missing, a
            string, or a list of content parts.
        mm_tracker: Tracker that collects the multimodal items.
        content_format: ``"openai"`` keeps the content as a list of dicts;
            any other value flattens it into a string.
        interleave_strings: Forwarded to the content-part parsing.

    Returns:
        The parsed messages, with "tool_calls" (assistant), "tool_call_id"
        (tool) and a string "name" copied over from ``message`` when present.
    """
    role = message["role"]
    content = message.get("content")

    if content is None:
        content = []
    elif isinstance(content, str):
        content = [ChatCompletionContentPartTextParam(type="text", text=content)]
    result = _parse_chat_message_content_parts(
        role,
        content,  # type: ignore
        mm_tracker,
        wrap_dicts=(content_format == "openai"),
        interleave_strings=interleave_strings,
    )

    name = message.get("name")

    for result_msg in result:
        if role == "assistant":
            parsed_msg = _AssistantParser(message)

            # The 'tool_calls' is not None check ensures compatibility.
            # It's needed only if downstream code doesn't strictly
            # follow the OpenAI spec.
            if "tool_calls" in parsed_msg and parsed_msg["tool_calls"] is not None:
                result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
        elif role == "tool":
            parsed_msg = _ToolParser(message)
            if "tool_call_id" in parsed_msg:
                result_msg["tool_call_id"] = parsed_msg["tool_call_id"]

        if isinstance(name, str):
            result_msg["name"] = name

    return result

1410

1411
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
    """Convert tool-call arguments of assistant messages to parsed objects.

    Per the Transformers docs & maintainers, tool call arguments in
    assistant-role messages with tool_calls need to be dicts, not JSON
    strings - this is how tool-use chat templates will expect them moving
    forwards. So, for messages that have tool_calls, parse the string (which
    we get from the OpenAI format) to a dict.

    The messages are modified in place. Arguments that are already a dict or
    list are left untouched, so calling this more than once on the same
    messages is safe.

    Args:
        messages: The conversation to post-process.
    """
    for message in messages:
        if message["role"] != "assistant":
            continue

        tool_calls = message.get("tool_calls")
        if not isinstance(tool_calls, list):
            continue

        for tool_call in tool_calls:
            function = tool_call["function"]
            arguments = function.get("arguments")
            if not arguments:
                # if arguments is None or empty string, set to {}
                function["arguments"] = {}
            elif not isinstance(arguments, (dict, list)):
                function["arguments"] = json.loads(arguments)
1429
1430


1431
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    Optional[MultiModalDataDict],
    Optional[MultiModalUUIDDict],
]:
    """Parse request messages into a conversation plus multimodal data.

    Args:
        messages: The chat messages of the request.
        model_config: Model configuration; its ``multimodal_config`` decides
            whether multimodal strings are interleaved.
        tokenizer: Tokenizer handed to the multimodal item tracker.
        content_format: Content format expected by the chat template.

    Returns:
        The conversation, the tracker's multimodal data and the tracker's
        multimodal UUIDs.
    """
    conversation: list[ConversationMessage] = []
    mm_tracker = MultiModalItemTracker(model_config, tokenizer)

    # Identical for every message, so evaluate it once.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    for msg in messages:
        sub_messages = _parse_chat_message_content(
            msg,
            mm_tracker,
            content_format,
            interleave_strings=interleave_strings,
        )

        conversation.extend(sub_messages)

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
1461
1462


1463
def parse_chat_messages_futures(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    Awaitable[Optional[MultiModalDataDict]],
    Optional[MultiModalUUIDDict],
]:
    """Parse request messages, returning the multimodal data as an awaitable.

    Same flow as ``parse_chat_messages`` but built on
    ``AsyncMultiModalItemTracker``, whose ``all_mm_data()`` result is
    returned without being awaited here.

    Args:
        messages: The chat messages of the request.
        model_config: Model configuration; its ``multimodal_config`` decides
            whether multimodal strings are interleaved.
        tokenizer: Tokenizer handed to the multimodal item tracker.
        content_format: Content format expected by the chat template.

    Returns:
        The conversation, an awaitable for the multimodal data and the
        tracker's multimodal UUIDs.
    """
    conversation: list[ConversationMessage] = []
    mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)

    # Identical for every message, so evaluate it once.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    for msg in messages:
        sub_messages = _parse_chat_message_content(
            msg,
            mm_tracker,
            content_format,
            interleave_strings=interleave_strings,
        )

        conversation.extend(sub_messages)

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
1493
1494


1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
    """Jinja2 extension that lets templates containing
    ``{% generation %} ... {% endgeneration %}`` blocks be parsed.

    Only ``parse`` is provided here; the ``_generation_support`` method that
    the produced node refers to is not defined in this class.
    """

    tags = {"generation"}

    def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
        """Turn a ``generation`` block into a ``CallBlock`` node."""
        # Consume the opening tag token first; its line number is attached
        # to the resulting node.
        opening_token = next(parser.stream)
        block_body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
        generation_call = self.call_method("_generation_support")
        node = jinja2.nodes.CallBlock(generation_call, [], [], block_body)
        return node.set_lineno(opening_token.lineno)


1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
def _resolve_chat_template_kwargs(
    chat_template: str,
):
    """Return the names of variables the template uses but does not declare.

    The template is only parsed (in an immutable sandboxed environment with
    the ``generation`` tag and loop controls enabled), never rendered.
    """
    sandbox = jinja2.sandbox.ImmutableSandboxedEnvironment(
        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
        lstrip_blocks=True,
        trim_blocks=True,
    )
    return jinja2.meta.find_undeclared_variables(sandbox.parse(chat_template))


# Memoized per template string, so each distinct template is parsed once
# while it stays in the cache.
_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)


1524
1525
1526
1527
1528
1529
def resolve_chat_template_kwargs(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    chat_template: str,
    chat_template_kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Filter ``chat_template_kwargs`` down to the entries that will be used.

    A key is kept when it is either an explicit keyword parameter of
    ``tokenizer.apply_chat_template`` or an undeclared variable of the
    template itself.

    Args:
        tokenizer: Tokenizer whose ``apply_chat_template`` signature is
            inspected.
        chat_template: The already-resolved template text.
        chat_template_kwargs: Candidate keyword arguments.

    Returns:
        A new dict with only the accepted entries; ``"chat_template"`` is
        always dropped.
    """
    fn_kw = {
        k
        for k in chat_template_kwargs
        if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
    }

    template_vars = _cached_resolve_chat_template_kwargs(chat_template)

    # We exclude chat_template from kwargs here, because
    # chat template has been already resolved at this stage
    unexpected_vars = {"chat_template"}
    accept_vars = (fn_kw | template_vars) - unexpected_vars
    return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars}
1542
1543


1544
1545
def apply_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    conversation: list[ConversationMessage],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    *,
    model_config: ModelConfig,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> str:
    """Render ``conversation`` with the tokenizer's chat template.

    Args:
        tokenizer: Hugging Face tokenizer used to apply the template.
        conversation: The parsed conversation messages.
        chat_template: Explicit template, or None to resolve one.
        tools: Optional tool definitions passed to the template.
        model_config: Model configuration used when resolving the template.
        tokenize: Forwarded to ``apply_chat_template``; defaults to False.
        **kwargs: Extra template arguments; entries neither accepted by
            ``apply_chat_template`` nor used by the template are dropped.

    Returns:
        The result of ``tokenizer.apply_chat_template``.

    Raises:
        ValueError: If no chat template can be resolved, or if resolving the
            kwargs or applying the template raises any exception (the
            original exception is chained as the cause).
    """
    hf_chat_template = resolve_hf_chat_template(
        tokenizer,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
    )

    if hf_chat_template is None:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        )

    try:
        resolved_kwargs = resolve_chat_template_kwargs(
            tokenizer=tokenizer,
            chat_template=hf_chat_template,
            chat_template_kwargs=kwargs,
        )
        return tokenizer.apply_chat_template(
            conversation=conversation,  # type: ignore[arg-type]
            tools=tools,  # type: ignore[arg-type]
            chat_template=hf_chat_template,
            tokenize=tokenize,
            **resolved_kwargs,
        )

    # External library exceptions can sometimes occur despite the framework's
    # internal exception management capabilities.
    except Exception as e:
        # Log and report any library-related exceptions for further
        # investigation.
        logger.exception(
            "An error occurred in `transformers` while applying chat template"
        )
        raise ValueError(str(e)) from e
1591

1592

1593
1594
def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: list[ChatCompletionMessageParam],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    **kwargs: Any,
) -> list[int]:
    """Apply the Mistral tokenizer's chat template to ``messages``.

    Args:
        tokenizer: The Mistral tokenizer.
        messages: The chat messages of the request.
        chat_template: Passed to ``resolve_mistral_chat_template`` only; it
            is not forwarded to the tokenizer.
        tools: Optional tool definitions.
        **kwargs: Forwarded to both ``resolve_mistral_chat_template`` and
            ``tokenizer.apply_chat_template``.

    Returns:
        The result of ``tokenizer.apply_chat_template``.

    Raises:
        ValueError: If applying the template raises any exception (the
            original exception is chained as the cause).
    """
    from mistral_common.exceptions import MistralCommonException

    # The return value of resolve_mistral_chat_template is always None,
    # and we won't use it.
    resolve_mistral_chat_template(
        chat_template=chat_template,
        **kwargs,
    )

    try:
        return tokenizer.apply_chat_template(
            messages=messages,
            tools=tools,
            **kwargs,
        )
    # mistral-common uses assert statements to stop processing of input
    # if input does not comply with the expected format.
    # We convert those assertion errors to ValueErrors so they can be
    # properly caught in the preprocessing_input step
    except (AssertionError, MistralCommonException) as e:
        raise ValueError(str(e)) from e

    # External library exceptions can sometimes occur despite the framework's
    # internal exception management capabilities.
    except Exception as e:
        # Log and report any library-related exceptions for further
        # investigation.
        logger.exception(
            "An error occurred in `mistral_common` while applying chat template"
        )
        raise ValueError(str(e)) from e
1631

1632

1633
1634
1635
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]) -> int:
    """Count the tool calls made by assistant messages in ``conversation``.

    Messages of other roles, and assistant messages whose "tool_calls" is
    missing or None, contribute nothing.
    """
    total = 0
    for msg in conversation:
        if msg["role"] != "assistant":
            continue
        tool_calls = msg.get("tool_calls")
        if tool_calls is not None:
            # "tool_calls" may be any iterable, so materialize it to count.
            total += len(list(tool_calls))
    return total


1642
1643
1644
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build an identifier for a tool call.

    Args:
        id_type: ``"kimi_k2"`` selects the ``functions.<name>:<idx>`` form;
            any other value yields a random identifier.
        func_name: Function name, used only for the ``"kimi_k2"`` form.
        idx: Call index, used only for the ``"kimi_k2"`` form.

    Returns:
        The identifier string.
    """
    if id_type == "kimi_k2":
        return f"functions.{func_name}:{idx}"

    # by default return random
    return f"chatcmpl-tool-{random_uuid()}"