chat_utils.py 47 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
from abc import ABC, abstractmethod
7
from collections import Counter, defaultdict, deque
8
from collections.abc import Awaitable, Iterable
9
from functools import cached_property, lru_cache, partial
10
from pathlib import Path
11
12
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
                    cast)
13

14
15
import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
16
17
# yapf conflicts with isort for this block
# yapf: disable
18
from openai.types.chat import (ChatCompletionAssistantMessageParam,
19
20
                               ChatCompletionContentPartImageParam,
                               ChatCompletionContentPartInputAudioParam)
21
22
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
23
24
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
                               ChatCompletionContentPartTextParam)
25
26
from openai.types.chat import (
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
27
28
from openai.types.chat import (ChatCompletionMessageToolCallParam,
                               ChatCompletionToolMessageParam)
29
30
from openai.types.chat.chat_completion_content_part_input_audio_param import (
    InputAudio)
31
from openai.types.responses import ResponseInputImageParam
32
from openai_harmony import Message as OpenAIHarmonyMessage
33
34
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
35
# yapf: enable
36
37
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
                          ProcessorMixin)
38
# pydantic needs the TypedDict from typing_extensions
39
from typing_extensions import Required, TypeAlias, TypedDict
40

41
from vllm.config import ModelConfig
42
from vllm.logger import init_logger
43
from vllm.model_executor.models import SupportsMultiModal
44
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
45
from vllm.multimodal.utils import MediaConnector
46
47
48
49
# yapf: disable
from vllm.transformers_utils.chat_templates import (
    get_chat_template_fallback_path)
# yapf: enable
50
from vllm.transformers_utils.processor import cached_get_processor
51
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
52
from vllm.utils import random_uuid
53
54
55

logger = init_logger(__name__)

56
57
58
59
60
61
# Generic, model-independent marker string for each modality.
# BaseMultiModalContentParser uses these values as the keys of its
# placeholder storage.
MODALITY_PLACEHOLDERS_MAP = {
    "image": "<##IMAGE##>",
    "audio": "<##AUDIO##>",
    "video": "<##VIDEO##>",
}

62

63
64
65
66
67
68
69
70
71
72
73
74
75
76
class AudioURL(TypedDict, total=False):
    """Location of an audio input; `url` is its only (required) key."""

    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Content part of type "audio_url" that wraps an `AudioURL`."""

    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


77
78
79
80
81
82
83
84
85
86
87
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
    """Content part of type "image_embeds" carrying encoded embeddings."""

    image_embeds: Required[Union[str, dict[str, str]]]
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""


88
89
90
91
92
93
94
95
96
97
98
99
100
101
class VideoURL(TypedDict, total=False):
    """Location of a video input; `url` is its only (required) key."""

    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Content part of type "video_url" that wraps a `VideoURL`."""

    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""


102
103
104
105
class PILImage(BaseModel):
    """
    A PIL.Image.Image object.
    """

    image_pil: Image.Image
    # `Image.Image` is an arbitrary (non-pydantic) class, so the model has to
    # opt in to arbitrary field types.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    image_pil: Required[PILImage]


123
124
125
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.
126

127
128
129
130
131
    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """
132

133
134
135
136
137
    image_url: Required[str]


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.
138

139
140
141
142
143
    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """
144

145
146
147
    audio_url: Required[str]


148
149
150
151
152
153
154
155
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """
156

157
158
159
    video_url: Required[str]


Julien Denize's avatar
Julien Denize committed
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    `thinking` and `type` are required; `closed` may be omitted.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


181
# Every shape a single content part of a message may take: the OpenAI
# content-part types, the audio/video/image-embeds/thinking extensions
# declared in this module, their "simple" single-key variants, and a bare
# string. The member order is kept exactly as originally declared.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam,
    ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartVideoParam,
    ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPILImageParam,
    CustomChatCompletionContentSimpleImageParam,
    ChatCompletionContentPartImageEmbedsParam,
    CustomChatCompletionContentSimpleAudioParam,
    CustomChatCompletionContentSimpleVideoParam,
    str,
    CustomThinkCompletionContentParam,
]
195
196
197
198


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""

    role: Required[str]
    """The role of the message's author."""

    content: Union[str, list[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""

219

220
221
222
223
224
# A chat message as accepted by this module: an OpenAI message param, a
# message with a custom role, or an OpenAI Harmony message.
ChatCompletionMessageParam = Union[
    OpenAIChatCompletionMessageParam,
    CustomChatCompletionMessageParam,
    OpenAIHarmonyMessage,
]
225
226


227
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict, total=False):
    """A single conversation message; only `role` is required."""

    role: Required[str]
    """The role of the message's author."""

    content: Union[Optional[str], list[dict[str, str]]]
    """The contents of the message"""

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    name: Optional[str]
    """The name of the function to call"""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""
243
244


245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
# Passed in by user. With "auto", the format is detected from the chat
# template (see resolve_chat_template_content_format).
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

# Used internally: the concrete format after "auto" has been resolved.
_ChatTemplateContentFormat = Literal["string", "openai"]


def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Return whether `node` is a load-context `Name` node for `varname`."""
    if not isinstance(node, jinja2.nodes.Name):
        return False

    return node.ctx == "load" and node.name == varname


def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """
    Return whether `node` reads `key` from the variable `varname`, either as
    `varname.key` (Getattr) or as `varname["key"]` (Getitem with a constant).
    """
    if isinstance(node, jinja2.nodes.Getattr):
        return _is_var_access(node.node, varname) and node.attr == key

    if isinstance(node, jinja2.nodes.Getitem):
        if not _is_var_access(node.node, varname):
            return False

        subscript = node.arg
        if not isinstance(subscript, jinja2.nodes.Const):
            return False

        return subscript.value == key

    return False


def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: Optional[str] = None,
) -> bool:
    """
    Return whether `node` accesses the variable `varname` (or, when `key` is
    given, the attribute/item `key` of `varname`), looking through filters,
    tests and slice subscripts that wrap the access.
    """
    # `x | some_filter`: inspect the filtered expression (it may be absent)
    if isinstance(node, jinja2.nodes.Filter):
        return node.node is not None and _is_var_or_elems_access(
            node.node, varname, key)

    # `x is some_test`: inspect the tested expression
    if isinstance(node, jinja2.nodes.Test):
        return _is_var_or_elems_access(node.node, varname, key)

    # `x[a:b]`: a slice still refers to the elements of `x`
    if isinstance(node, jinja2.nodes.Getitem) and isinstance(
            node.arg, jinja2.nodes.Slice):
        return _is_var_or_elems_access(node.node, varname, key)

    if key:
        return _is_attr_access(node, varname, key)

    return _is_var_access(node, varname)


def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
    """
    Yield `(node, name)` pairs for `varname` itself and for every variable
    that is assigned, directly or transitively, from it.

    The first pair is `(root, varname)`. Each further pair is an `Assign`
    node whose right-hand side accesses an already known name, together with
    the name on its left-hand side.
    """
    # Global variable that is implicitly defined at the root
    yield root, varname

    # Iterative BFS over the known names. Each name is expanded at most once,
    # so self-assignments and assignment cycles (e.g. `a = messages` followed
    # by `messages = a`) terminate instead of re-queueing names forever.
    seen_varnames = {varname}
    related_varnames = deque([varname])
    while related_varnames:
        related_varname = related_varnames.popleft()

        for assign_ast in root.find_all(jinja2.nodes.Assign):
            lhs = assign_ast.target
            rhs = assign_ast.node

            if _is_var_or_elems_access(rhs, related_varname):
                assert isinstance(lhs, jinja2.nodes.Name)
                yield assign_ast, lhs.name

                if lhs.name not in seen_varnames:
                    seen_varnames.add(lhs.name)
                    related_varnames.append(lhs.name)


# NOTE: Scoping is ignored here. Doing this properly would require building a
# CFG to know where each variable is defined, which is too complicated.
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
    """
    Yield `(for_node, target_name)` for each `for` loop that iterates over
    `messages` or over a variable assigned from it.
    """
    aliases = [
        name for _, name in _iter_nodes_assign_var_or_elems(root, "messages")
    ]

    # Search for {%- for message in messages -%} loops
    for for_node in root.find_all(jinja2.nodes.For):
        iterates_messages = any(
            _is_var_or_elems_access(for_node.iter, alias)
            for alias in aliases)
        if not iterates_messages:
            continue

        target = for_node.target
        assert isinstance(target, jinja2.nodes.Name)
        yield for_node, target.name


def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    """
    Yield `(for_node, target_name)` for each `for` loop that iterates over
    the `content` of a message loop variable.
    """
    message_aliases = [
        name for _, name in _iter_nodes_assign_messages_item(root)
    ]

    # Search for {%- for content in message['content'] -%} loops
    for for_node in root.find_all(jinja2.nodes.For):
        iterates_content = any(
            _is_var_or_elems_access(for_node.iter, alias, "content")
            for alias in message_aliases)
        if not iterates_content:
            continue

        target = for_node.target
        assert isinstance(target, jinja2.nodes.Name)
        yield for_node, target.name


def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
    """
    Parse `chat_template` into a Jinja AST.

    Any failure while compiling or parsing is logged with its traceback and
    reported as `None` instead of being raised.
    """
    try:
        environment = hf_chat_utils._compile_jinja_template(
            chat_template).environment
        template_ast = environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None

    return template_ast


361
@lru_cache(maxsize=32)
def _detect_content_format(
    chat_template: str,
    *,
    default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
    """
    Detect the content format a Jinja chat template expects.

    Returns "openai" if the template contains a loop over a message's
    `content`, "string" if it contains none, and `default` if the template
    cannot be parsed or its AST cannot be analysed. Results are cached per
    `(chat_template, default)` pair.
    """
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    try:
        # Only the existence of a match matters, so stop at the first one
        next(_iter_nodes_assign_content_item(jinja_ast))
    except StopIteration:
        return "string"
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default
    else:
        return "openai"


382
383
384
385
386
387
def resolve_mistral_chat_template(
    chat_template: Optional[str],
    **kwargs: Any,
) -> Optional[str]:
    """
    Warn about options that the mistral tokenizer ignores.

    Always returns `None`: a warning is logged when `chat_template` is given
    and when `add_generation_prompt` or `continue_final_message` appears in
    `kwargs`, but none of them is applied.
    """
    if chat_template is not None:
        logger.warning_once(
            "'chat_template' cannot be overridden for mistral tokenizer."
        )
    if "add_generation_prompt" in kwargs:
        logger.warning_once(
            "'add_generation_prompt' is not supported for mistral tokenizer, "
            "so it will be ignored."
        )
    if "continue_final_message" in kwargs:
        logger.warning_once(
            "'continue_final_message' is not supported for mistral tokenizer, "
            "so it will be ignored."
        )
    return None

402

403
def resolve_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    *,
    model_config: ModelConfig,
) -> Optional[str]:
    """
    Resolve the chat template to use for a HF tokenizer.

    Sources are tried in priority order: the given `chat_template`, the
    processor's template (only when `tools` is None), the tokenizer's
    template, and finally a predefined fallback file. Failures of the
    processor and tokenizer lookups are logged at debug level and the next
    source is tried. Returns `None` if no source provides a template.
    """
    # 1st priority: The given chat template
    if chat_template is not None:
        return chat_template

    # 2nd priority: AutoProcessor chat template, unless tool calling is enabled
    if tools is None:
        try:
            processor = cached_get_processor(
                tokenizer.name_or_path,
                processor_cls=(
                    PreTrainedTokenizer,
                    PreTrainedTokenizerFast,
                    ProcessorMixin,
                ),
                trust_remote_code=model_config.trust_remote_code,
            )
            if (
                isinstance(processor, ProcessorMixin)
                and hasattr(processor, "chat_template")
                and processor.chat_template is not None
            ):
                return processor.chat_template
        except Exception:
            logger.debug(
                "Failed to load AutoProcessor chat template for %s",
                tokenizer.name_or_path,
                exc_info=True,
            )

    # 3rd priority: AutoTokenizer chat template
    # (`chat_template` is always None at this point)
    try:
        return tokenizer.get_chat_template(chat_template, tools=tools)
    except Exception:
        logger.debug(
            "Failed to load AutoTokenizer chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    # 4th priority: Predefined fallbacks
    path = get_chat_template_fallback_path(
        model_type=model_config.hf_config.model_type,
        tokenizer_name_or_path=model_config.tokenizer,
    )
    if path is not None:
        logger.info(
            "Loading chat template fallback for %s as there isn't one "
            "defined on HF Hub.",
            tokenizer.name_or_path,
        )
        chat_template = load_chat_template(path)
    else:
        logger.debug(
            "There is no chat template fallback for %s", tokenizer.name_or_path
        )

    return chat_template
467
468


469
470
def _resolve_chat_template_content_format(
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """
    Detect the content format from the chat template that will be used.

    For HF tokenizers the template is resolved via
    `resolve_hf_chat_template`; otherwise, or if that yields no string, the
    given `chat_template` is used literally. Returns "string" when no
    template text is available.
    """
    if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        hf_chat_template = resolve_hf_chat_template(
            tokenizer,
            chat_template=chat_template,
            tools=tools,
            model_config=model_config,
        )
    else:
        hf_chat_template = None

    jinja_text = (
        hf_chat_template
        if isinstance(hf_chat_template, str)
        else load_chat_template(chat_template, is_literal=True)
    )

    detected_format = (
        "string"
        if jinja_text is None
        else _detect_content_format(jinja_text, default="string")
    )

    return detected_format
499
500
501


@lru_cache
def _log_chat_template_content_format(
    chat_template: Optional[str],
    given_format: ChatTemplateContentFormatOption,
    detected_format: ChatTemplateContentFormatOption,
):
    """
    Log the detected content format, and warn if it differs from an
    explicitly given one.

    `lru_cache` makes a repeated call with the same arguments return the
    cached result without logging again. `chat_template` is not used in the
    body; it only takes part in the cache key.
    """
    logger.info(
        "Detected the chat template content format to be '%s'. "
        "You can set `--chat-template-content-format` to override this.",
        detected_format,
    )

    if given_format != "auto" and given_format != detected_format:
        logger.warning(
            "You specified `--chat-template-content-format %s` "
            "which is different from the detected format '%s'. "
            "If our automatic detection is incorrect, please consider "
            "opening a GitHub issue so that we can improve it: "
            "https://github.com/vllm-project/vllm/issues/new/choose",
            given_format,
            detected_format,
        )

524
525
526
527
528
529

def resolve_chat_template_content_format(
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    given_format: ChatTemplateContentFormatOption,
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """
    Resolve the content format to use for a chat template.

    An explicit `given_format` ("string" or "openai") is returned as-is.
    For "auto", the format is detected from the resolved chat template and
    the outcome is logged.
    """
    if given_format != "auto":
        return given_format

    detected_format = _resolve_chat_template_content_format(
        chat_template,
        tools,
        tokenizer,
        model_config=model_config,
    )

    _log_chat_template_content_format(
        chat_template,
        given_format=given_format,
        detected_format=detected_format,
    )

    return detected_format
550

551

552
# Modality keys accepted by the multi-modal item trackers in this module.
ModalityStr = Literal["image", "audio", "video", "image_embeds"]
553
554
555
556
# Type of the items stored by a BaseMultiModalItemTracker.
_T = TypeVar("_T")


class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """

    def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
        super().__init__()

        self._model_config = model_config
        self._tokenizer = tokenizer

        # Items added so far, grouped by the modality they were added under
        self._items_by_modality = defaultdict[str, list[_T]](list)

    @property
    def model_config(self) -> ModelConfig:
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        # Imported inside the method rather than at module level
        # (presumably to avoid an import cycle or a heavy import - confirm)
        from vllm.model_executor.model_loader import get_model_cls

        model_cls = get_model_cls(self.model_config)
        return cast(type[SupportsMultiModal], model_cls)

    @property
    def allowed_local_media_path(self):
        return self._model_config.allowed_local_media_path

    @property
    def mm_registry(self):
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        return self.mm_registry.create_processor(self.model_config)

    def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
        """
        Add a multi-modal item to the current prompt and returns the
        placeholder string to use, if any.
        """
        # The item count is validated against the base modality,
        # e.g. "image_embeds" is validated as "image"
        input_modality = modality.replace("_embeds", "")
        # Use .get() so that an item rejected by the validation below does
        # not leave an empty entry behind in the defaultdict.
        num_items = len(self._items_by_modality.get(modality, ())) + 1

        self.mm_processor.validate_num_items(input_modality, num_items)

        self._items_by_modality[modality].append(item)

        return self.model_cls.get_placeholder_str(modality, num_items)

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        raise NotImplementedError


613
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
    """Tracker whose items are already-loaded objects."""

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """
        Return the tracked items keyed by "image", "audio" and "video",
        or `None` if nothing has been tracked.

        Raises:
            ValueError: If both raw images and image embeddings were added,
                or if more than one image-embeds item was added.
        """
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        items_by_modality = dict(self._items_by_modality)
        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed"
            )

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}"
                )
            # Embeddings are passed on under the "image" key, unwrapped
            mm_inputs["image"] = image_embeds_lst[0]
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)


643
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
    """Tracker whose items are awaitables that resolve to the loaded data."""

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """
        Await all tracked items and return them keyed by "image", "audio"
        and "video", or `None` if nothing has been tracked.

        Raises:
            ValueError: If both raw images and image embeddings were added,
                or if more than one image-embeds item was added.
        """
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        # The items of one modality are gathered together; the modalities
        # themselves are awaited one after another.
        items_by_modality = {
            modality: await asyncio.gather(*items)
            for modality, items in self._items_by_modality.items()
        }

        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed"
            )

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}"
                )
            # Embeddings are passed on under the "image" key, unwrapped
            mm_inputs["image"] = image_embeds_lst[0]
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    """
    Interface for parsing the multi-modal content parts of a message.

    Subclasses implement the `parse_*` methods. This base class only keeps
    track of the placeholder strings collected for each modality.
    """

    def __init__(self) -> None:
        super().__init__()

        # stores model placeholders list with corresponding
        # general MM placeholder:
        # {
        #   "<##IMAGE##>": ["<image>", "<image>", "<image>"],
        #   "<##AUDIO##>": ["<audio>", "<audio>"]
        # }
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(
        self, modality: ModalityStr, placeholder: Optional[str]
    ):
        """Record `placeholder` under the modality's generic marker.

        Empty or `None` placeholders are not stored. `modality` must be a
        key of `MODALITY_PLACEHOLDERS_MAP`.
        """
        mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
        if placeholder:
            self._placeholder_storage[mod_placeholder].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return a shallow copy of the storage (the lists are shared)."""
        return dict(self._placeholder_storage)

    @abstractmethod
    def parse_image(self, image_url: str) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self, image_embeds: Union[str, dict[str, str]]
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(self, image_pil: Image.Image) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(self, input_audio: InputAudio) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str) -> None:
        raise NotImplementedError

725
726
727
728
729
730
731

class MultiModalContentParser(BaseMultiModalContentParser):
    """
    Synchronous content parser: every media item is fetched immediately and
    the loaded object is registered with the tracker.
    """

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

        self._connector = MediaConnector(
            media_io_kwargs=self._tracker._model_config.media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
        )

    def parse_image(self, image_url: str) -> None:
        image = self._connector.fetch_image(image_url)

        placeholder = self._tracker.add("image", image)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self, image_embeds: Union[str, dict[str, str]]
    ) -> None:
        """Decode the embeddings and track them as "image_embeds".

        Raises:
            TypeError: If `image_embeds` is neither a string nor a dict.
        """
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            placeholder = self._tracker.add("image_embeds", embeds)
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            placeholder = self._tracker.add("image_embeds", embedding)
        else:
            # Without this branch `placeholder` would be unbound below
            raise TypeError(
                "image_embeds must be a string or a dict of strings, "
                f"but got {type(image_embeds)}"
            )

        # The placeholder is recorded under the "image" marker
        self._add_placeholder("image", placeholder)

    def parse_image_pil(self, image_pil: Image.Image) -> None:
        placeholder = self._tracker.add("image", image_pil)
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str) -> None:
        audio = self._connector.fetch_audio(audio_url)

        placeholder = self._tracker.add("audio", audio)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(self, input_audio: InputAudio) -> None:
        # Re-encode the inline audio as a data URL and reuse parse_audio
        audio_data = input_audio.get("data", "")
        audio_format = input_audio.get("format", "")
        audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url)

    def parse_video(self, video_url: str) -> None:
        video = self._connector.fetch_video(video_url=video_url)

        placeholder = self._tracker.add("video", video)
        self._add_placeholder("video", placeholder)
781

782
783
784
785
786
787

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """
    Asynchronous content parser: what gets registered with the tracker is
    an awaitable (a fetch coroutine or an already-resolved future).
    """

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        self._connector = MediaConnector(
            media_io_kwargs=self._tracker._model_config.media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
        )

    def parse_image(self, image_url: str) -> None:
        image_coro = self._connector.fetch_image_async(image_url)

        placeholder = self._tracker.add("image", image_coro)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self, image_embeds: Union[str, dict[str, str]]
    ) -> None:
        """Decode the embeddings and track them as "image_embeds".

        The embeddings are decoded synchronously and wrapped in an
        already-resolved future.

        Raises:
            TypeError: If `image_embeds` is neither a string nor a dict.
        """
        future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()

        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            future.set_result(embeds)
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            future.set_result(embedding)
        else:
            # Never register an unresolved future: awaiting it later would
            # block forever.
            raise TypeError(
                "image_embeds must be a string or a dict of strings, "
                f"but got {type(image_embeds)}"
            )

        placeholder = self._tracker.add("image_embeds", future)
        # The placeholder is recorded under the "image" marker
        self._add_placeholder("image", placeholder)

    def parse_image_pil(self, image_pil: Image.Image) -> None:
        # The image is already loaded; wrap it in a resolved future
        future: asyncio.Future[Image.Image] = asyncio.Future()
        future.set_result(image_pil)

        placeholder = self._tracker.add("image", future)
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str) -> None:
        audio_coro = self._connector.fetch_audio_async(audio_url)

        placeholder = self._tracker.add("audio", audio_coro)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(self, input_audio: InputAudio) -> None:
        # Re-encode the inline audio as a data URL and reuse parse_audio
        audio_data = input_audio.get("data", "")
        audio_format = input_audio.get("format", "")
        audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url)

    def parse_video(self, video_url: str) -> None:
        video = self._connector.fetch_video_async(video_url=video_url)

        placeholder = self._tracker.add("video", video)
        self._add_placeholder("video", placeholder)
843

844

845
846
847
848
849
850
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
    """Raises if the provided chat template appears invalid.

    Args:
        chat_template: `None`, a path to a template file, or a string that
            is either a literal Jinja template or a path to a template file.

    Raises:
        FileNotFoundError: If a `Path` is given that does not exist.
        ValueError: If a string contains no Jinja-like characters and does
            not name an existing file.
        TypeError: If the value is neither `None`, a `Path` nor a `str`.
    """
    if chat_template is None:
        return

    if isinstance(chat_template, Path):
        if not chat_template.exists():
            raise FileNotFoundError(
                "the supplied chat template path doesn't exist")
        # An existing path is valid; it must not reach the TypeError below
        return

    if isinstance(chat_template, str):
        # A string without any of these is assumed to be a file path
        JINJA_CHARS = "{}\n"
        if (
            not any(c in chat_template for c in JINJA_CHARS)
            and not Path(chat_template).exists()
        ):
            raise ValueError(
                f"The supplied chat template string ({chat_template}) "
                f"appears path-like, but doesn't exist!"
            )
        return

    raise TypeError(
        f"{type(chat_template)} is not a valid chat template type"
    )
868
869


870
def _load_chat_template(
    chat_template: Optional[Union[Path, str]],
    *,
    is_literal: bool = False,
) -> Optional[str]:
    """Resolve a chat template argument to the template text.

    Args:
        chat_template: ``None``, a ``Path`` to a template file, or a string
            holding either a file path or the literal template.
        is_literal: When true, ``chat_template`` is returned as-is instead
            of being opened as a file.

    Returns:
        The template text, or ``None`` if ``chat_template`` is ``None``.

    Raises:
        TypeError: If ``is_literal`` is set but a ``Path`` was supplied.
        OSError: If a ``Path`` was supplied and it cannot be opened.
        ValueError: If a string that contains none of ``{``, ``}`` or a
            newline (so it looks like a path) cannot be opened.
    """
    if chat_template is None:
        return None

    if is_literal:
        if isinstance(chat_template, Path):
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )

        return chat_template

    try:
        # Decode explicitly as UTF-8 so that templates containing non-ASCII
        # characters load the same way regardless of the platform locale.
        with open(chat_template, encoding="utf-8") as f:
            return f.read()
    except OSError as e:
        if isinstance(chat_template, Path):
            raise

        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
            msg = (
                f"The supplied chat template ({chat_template}) "
                f"looks like a file path, but it failed to be "
                f"opened. Reason: {e}"
            )
            raise ValueError(msg) from e

        # The string could not be opened as a file but contains Jinja
        # characters, so treat it as the literal template text.
        return _load_chat_template(chat_template, is_literal=True)


# Memoized variant of `_load_chat_template`, wrapped with `functools.lru_cache`
# using its default settings.
# NOTE(review): results are keyed on the call arguments, so a template file
# edited on disk after its first load is presumably not re-read -- confirm
# this is acceptable for callers.
_cached_load_chat_template = lru_cache(_load_chat_template)


def load_chat_template(
    chat_template: Optional[Union[Path, str]],
    *,
    is_literal: bool = False,
) -> Optional[str]:
    """Return the chat template text via the cached loader.

    This is a thin wrapper that forwards both arguments unchanged to
    ``_cached_load_chat_template`` and returns its result.
    """
    template_text = _cached_load_chat_template(
        chat_template,
        is_literal=is_literal,
    )
    return template_text
916
917


918
919
920
def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
    """Substitute stored placeholders into ``texts`` and join with newlines.

    Every entry of ``texts`` that is a key of ``placeholder_storage`` is
    replaced by the first remaining item of that key's list. Both arguments
    are modified in place: ``texts`` receives the substitutions and the
    consumed items are removed from ``placeholder_storage``.
    """
    for position, segment in enumerate(texts):
        pending = placeholder_storage.get(segment)
        if pending is None:
            # Ordinary text: leave it where it is.
            continue
        texts[position] = pending.pop(0)

    return "\n".join(texts)


928
# TODO: Let user specify how to insert multimodal tokens into prompt
929
# (similar to chat template)
930
931
932
933
934
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
) -> str:
    """Combine multimodal prompts for a multimodal language model.

    The text pieces are joined with newlines (after interleaved substitution
    when ``interleave_strings`` is set). Any registered placeholder that does
    not occur often enough in the resulting text is prepended, one per line.

    Raises:
        ValueError: If the text contains more occurrences of a placeholder
            than were registered in ``placeholder_storage``.
    """
    # Tally how often each placeholder string was registered, e.g.
    # {"<|image|>": 2, "<|audio|>": 1}. This is done before interleaving
    # because interleaving removes items from ``placeholder_storage``.
    placeholder_counts = Counter()
    for registered in placeholder_storage.values():
        placeholder_counts.update(registered)

    # The interleaved text still goes through the check below, in case the
    # user wrote placeholders manually but left 'interleave_strings' enabled.
    text_prompt = (
        _get_interleaved_text_prompt(placeholder_storage, texts)
        if interleave_strings
        else "\n".join(texts)
    )

    # Collect the placeholders that the text prompt does not already contain.
    missing_placeholders: list[str] = []
    for placeholder, registered_count in placeholder_counts.items():
        # Placeholders already present in the text are left where they are.
        shortfall = registered_count - text_prompt.count(placeholder)

        if shortfall < 0:
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
                "when manually placing image placeholders.",
                interleave_strings,
            )
            logger.debug("Input prompt: %s", text_prompt)
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
                "actual multimodal data items."
            )

        missing_placeholders.extend([placeholder] * shortfall)

    # NOTE: Missing placeholders are always placed at the front of the prompt.
    return "\n".join([*missing_placeholders, text_prompt])
981
982


983
984
# These part types are not validated with Pydantic again; the "parsers" below
# only `cast` the part to the matching TypedDict for the type checker.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)

# Parts that carry URL objects do need to be validated.
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ResponsesInputImageParser = TypeAdapter(
    ResponseInputImageParam
).validate_python

# Payload types that a parsed content part may yield.
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]
999

1000
# Dispatch table from a content part's ``type`` to the function that extracts
# that part's payload. Each extractor yields ``None`` when the expected field
# is absent from the part.
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    "text": lambda part: _TextParser(part).get("text"),
    "thinking": lambda part: _ThinkParser(part).get("thinking"),
    "input_text": lambda part: _TextParser(part).get("text"),
    "input_image": lambda part: (
        _ResponsesInputImageParser(part).get("image_url")
    ),
    "image_url": lambda part: (
        _ImageParser(part).get("image_url", {}).get("url")
    ),
    "image_embeds": lambda part: (
        _ImageEmbedsParser(part).get("image_embeds")
    ),
    "image_pil": lambda part: _PILImageParser(part).get("image_pil"),
    "audio_url": lambda part: (
        _AudioParser(part).get("audio_url", {}).get("url")
    ),
    "input_audio": lambda part: _InputAudioParser(part).get("input_audio"),
    "refusal": lambda part: _RefusalParser(part).get("refusal"),
    "video_url": lambda part: (
        _VideoParser(part).get("video_url", {}).get("url")
    ),
}


def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
    """
    Parses a given multi-modal content part based on its type.

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL). For a string
          'type' that has no entry in MM_PARSER_MAP, a fixed dummy string
          is returned as content.

    Raises:
        ValueError: If the 'type' field is missing and no direct URL field
            is found, or if 'type' is present but is not a string.
    """
    assert isinstance(
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
    part_type = part.get("type", None)

    if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
        # We only support 'auto', which is the default
        if part_type == "image_url":
            # In the OpenAI schema 'detail' is nested inside the 'image_url'
            # object, so look there first; a top-level 'detail' key is still
            # checked for backward compatibility.
            image_url = part.get("image_url")
            nested_detail = (
                image_url.get("detail", "auto")
                if isinstance(image_url, dict)
                else "auto"
            )
            if nested_detail != "auto" or part.get("detail", "auto") != "auto":
                logger.warning(
                    "'image_url.detail' is currently not supported "
                    "and will be ignored."
                )

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
    # 'type' is required field by pydantic
    if part_type is None:
        if part.get("image_url") is not None:
            image_params = cast(
                CustomChatCompletionContentSimpleImageParam, part
            )
            return "image_url", image_params.get("image_url", "")
        if part.get("audio_url") is not None:
            audio_params = cast(
                CustomChatCompletionContentSimpleAudioParam, part
            )
            return "audio_url", audio_params.get("audio_url", "")
        if part.get("input_audio") is not None:
            input_audio_params = cast(dict[str, str], part)
            return "input_audio", input_audio_params
        if part.get("video_url") is not None:
            video_params = cast(
                CustomChatCompletionContentSimpleVideoParam, part
            )
            return "video_url", video_params.get("video_url", "")
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    # A string type that MM_PARSER_MAP does not know about: return it with
    # dummy content and leave the decision to the caller.
    return part_type, "unknown part_type content"


1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
# Part types for which `_parse_chat_message_content_part` logs a warning and
# skips the part when its parsed content is None.
# NOTE(review): "thinking", "input_text" and "input_image" have entries in
# MM_PARSER_MAP but are not listed here -- confirm whether that is intended.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = (
    "text",
    "refusal",
    "image_url",
    "image_embeds",
    "image_pil",
    "audio_url",
    "input_audio",
    "video_url",
)
1105

1106

1107
1108
1109
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse all content parts of one message into a conversation message.

    Each part is parsed with a parser created from ``mm_tracker``; falsy
    parse results are dropped. With ``wrap_dicts`` the parsed parts are kept
    as a list; otherwise they are combined into a single text prompt,
    including multimodal placeholders when the parser has recorded any.

    Returns:
        A one-element list holding the resulting message for ``role``.
    """
    mm_parser = mm_tracker.create_parser()

    parsed_parts = (
        _parse_chat_message_content_part(
            part,
            mm_parser,
            wrap_dicts=wrap_dicts,
            interleave_strings=interleave_strings,
        )
        for part in parts
    )
    content: list[_ContentPart] = [item for item in parsed_parts if item]

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role, content=content)]  # type: ignore

    texts = cast(list[str], content)
    mm_placeholder_storage = mm_parser.mm_placeholder_storage()
    text_prompt = (
        _get_full_multimodal_text_prompt(
            mm_placeholder_storage, texts, interleave_strings
        )
        if mm_placeholder_storage
        else "\n".join(texts)
    )
    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> Optional[_ContentPart]:
    """Parse a single part of a conversation.

    With ``wrap_dicts`` set, text parts come back as
    ``{"type": "text", "text": ...}`` and multimodal parts as
    ``{"type": <modality>}``. Otherwise text parts come back as plain
    strings, while multimodal parts are handed to ``mm_parser`` and yield
    either the modality's placeholder (when ``interleave_strings`` is set)
    or ``None``.

    Raises:
        NotImplementedError: If the part type is not one handled here.
    """
    # Plain strings are already text; pass them through untouched.
    if isinstance(part, str):
        return part

    # Everything else is a structured dictionary part.
    part_type, content = _parse_chat_message_content_mm_part(part)

    # A known part type whose content could not be extracted is skipped
    # with a warning.
    if content is None and part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES:
        logger.warning(
            "Skipping multimodal part '%s' (type: '%s') "
            "with empty / unparsable content.",
            part,
            part_type,
        )
        return None

    if part_type in ("text", "input_text", "refusal", "thinking"):
        text = cast(str, content)
        return {"type": "text", "text": text} if wrap_dicts else text

    # Multimodal parts: register the data with the parser and remember
    # which modality it belongs to.
    if part_type == "image_pil":
        mm_parser.parse_image_pil(cast(Image.Image, content))
        modality = "image"
    elif part_type in ("image_url", "input_image"):
        mm_parser.parse_image(cast(str, content))
        modality = "image"
    elif part_type == "image_embeds":
        mm_parser.parse_image_embeds(
            cast(Union[str, dict[str, str]], content)
        )
        modality = "image"
    elif part_type == "audio_url":
        mm_parser.parse_audio(cast(str, content))
        modality = "audio"
    elif part_type == "input_audio":
        mm_parser.parse_input_audio(cast(InputAudio, content))
        modality = "audio"
    elif part_type == "video_url":
        mm_parser.parse_video(cast(str, content))
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")

    if wrap_dicts:
        return {"type": modality}
    if interleave_strings:
        return MODALITY_PLACEHOLDERS_MAP[modality]
    return None
1215
1216


1217
1218
1219
1220
1221
# No need to validate using Pydantic again: these helpers only `cast` a
# message to the assistant / tool message TypedDict for the type checker.
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


1222
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: _ChatTemplateContentFormat,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse one chat message into conversation message(s).

    Missing content is treated as an empty list of parts and string content
    as a single text part. After parsing, ``tool_calls`` (assistant role),
    ``tool_call_id`` (tool role) and a string ``name`` are copied from the
    input message onto each resulting message when present.
    """
    role = message["role"]
    raw_content = message.get("content")

    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [
            ChatCompletionContentPartTextParam(type="text", text=raw_content)
        ]
    else:
        parts = raw_content

    result = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
        wrap_dicts=(content_format == "openai"),
        interleave_strings=interleave_strings,
    )

    name = message.get("name")

    for result_msg in result:
        if role == "assistant":
            assistant_msg = _AssistantParser(message)

            # The explicit ``None`` check keeps compatibility with callers
            # that do not strictly follow the OpenAI spec.
            tool_calls = assistant_msg.get("tool_calls")
            if tool_calls is not None:
                result_msg["tool_calls"] = list(tool_calls)
        elif role == "tool":
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                result_msg["tool_call_id"] = tool_msg["tool_call_id"]

        if isinstance(name, str):
            result_msg["name"] = name

    return result

1267

1268
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
    """Decode tool-call arguments of assistant messages, in place.

    Per the Transformers docs & maintainers, tool call arguments in
    assistant-role messages with tool_calls need to be dicts, not JSON
    strings - this is how tool-use chat templates expect them. The OpenAI
    format delivers them as JSON strings, so they are parsed here.

    Missing, ``None`` or empty arguments become ``{}``. Arguments that are
    already a dict or list are left untouched, which makes this function
    safe to run more than once over the same tool-call objects.

    Raises:
        json.JSONDecodeError: If a non-empty arguments string is not valid
            JSON.
    """
    for message in messages:
        if message["role"] != "assistant":
            continue

        tool_calls = message.get("tool_calls")
        if not isinstance(tool_calls, list):
            continue

        for item in tool_calls:
            function = item["function"]
            arguments = function.get("arguments")
            if not arguments:
                # Nothing to decode; templates still expect a mapping.
                function["arguments"] = {}
            elif not isinstance(arguments, (dict, list)):
                function["arguments"] = json.loads(arguments)
1284
1285


1286
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> tuple[list[ConversationMessage], Optional[MultiModalDataDict]]:
    """Parse chat messages into a conversation plus its multimodal data.

    Every message is parsed against a shared ``MultiModalItemTracker``; the
    resulting conversation is post-processed and returned together with
    ``mm_tracker.all_mm_data()``.
    """
    mm_tracker = MultiModalItemTracker(model_config, tokenizer)

    conversation: list[ConversationMessage] = [
        parsed_msg
        for msg in messages
        for parsed_msg in _parse_chat_message_content(
            msg,
            mm_tracker,
            content_format,
            # Interleaving only applies to the "string" content format and
            # when the multimodal config enables it.
            interleave_strings=(
                content_format == "string"
                and model_config.multimodal_config is not None
                and model_config.multimodal_config.interleave_mm_strings
            ),
        )
    ]

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data()
1312
1313


1314
def parse_chat_messages_futures(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> tuple[list[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
    """Parse chat messages using the asynchronous multimodal tracker.

    Every message is parsed against a shared ``AsyncMultiModalItemTracker``;
    the resulting conversation is post-processed and returned together with
    ``mm_tracker.all_mm_data()``.
    """
    mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)

    conversation: list[ConversationMessage] = [
        parsed_msg
        for msg in messages
        for parsed_msg in _parse_chat_message_content(
            msg,
            mm_tracker,
            content_format,
            # Interleaving only applies to the "string" content format and
            # when the multimodal config enables it.
            interleave_strings=(
                content_format == "string"
                and model_config.multimodal_config is not None
                and model_config.multimodal_config.interleave_mm_strings
            ),
        )
    ]

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data()


1342
1343
def apply_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    conversation: list[ConversationMessage],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    *,
    model_config: ModelConfig,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> str:
    """Apply the resolved chat template through the HF tokenizer.

    The template is resolved with ``resolve_hf_chat_template`` and passed,
    together with the conversation, tools and extra keyword arguments, to
    ``tokenizer.apply_chat_template``.

    Raises:
        ValueError: If no chat template could be resolved, or if applying
            the template raises any exception (the original exception is
            logged and chained as the cause).
    """
    resolved_template = resolve_hf_chat_template(
        tokenizer,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
    )

    if resolved_template is None:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        )

    try:
        rendered = tokenizer.apply_chat_template(
            conversation=conversation,  # type: ignore[arg-type]
            tools=tools,  # type: ignore[arg-type]
            chat_template=resolved_template,
            tokenize=tokenize,
            **kwargs,
        )
    # Exceptions from the external library are logged for investigation and
    # surfaced to the caller as ValueError.
    except Exception as e:
        logger.exception(
            "An error occurred in `transformers` while applying chat template"
        )
        raise ValueError(str(e)) from e
    else:
        return rendered
1384

1385

1386
1387
def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: list[ChatCompletionMessageParam],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    **kwargs: Any,
) -> list[int]:
    """Apply the chat template through the Mistral tokenizer.

    Raises:
        ValueError: If ``tokenizer.apply_chat_template`` raises any
            exception; the original exception is chained as the cause.
    """
    from mistral_common.exceptions import MistralCommonException

    # Called for its own effects only: the return value is always None and
    # is not used.
    resolve_mistral_chat_template(chat_template=chat_template, **kwargs)

    try:
        token_ids = tokenizer.apply_chat_template(
            messages=messages,
            tools=tools,
            **kwargs,
        )
    # mistral-common uses assert statements to reject input that does not
    # comply with the expected format. Those are converted to ValueError so
    # they can be properly caught in the preprocessing_input step.
    except (AssertionError, MistralCommonException) as e:
        raise ValueError(str(e)) from e
    # Any other exception from the external library is logged for
    # investigation and likewise surfaced as ValueError.
    except Exception as e:
        logger.exception(
            "An error occurred in `mistral_common` while applying chat template"
        )
        raise ValueError(str(e)) from e
    else:
        return token_ids
1424

1425

1426
1427
1428
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
    """Count the tool calls made by assistant messages in ``conversation``.

    Messages with another role, and assistant messages whose ``tool_calls``
    entry is missing or ``None``, contribute nothing.
    """
    total = 0
    for message in conversation:
        if message["role"] != "assistant":
            continue
        calls = message.get("tool_calls")
        if calls is None:
            continue
        # ``tool_calls`` may be any iterable, so materialize it to count.
        total += len(list(calls))
    return total


1435
1436
1437
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build a tool call ID.

    For ``id_type == "kimi_k2"`` the ID is ``functions.<func_name>:<idx>``.
    Any other ``id_type`` yields ``chatcmpl-tool-`` followed by the value of
    ``random_uuid()``.
    """
    if id_type != "kimi_k2":
        # by default return random
        return f"chatcmpl-tool-{random_uuid()}"
    return f"functions.{func_name}:{idx}"