chat_utils.py 50.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
from abc import ABC, abstractmethod
7
from collections import Counter, defaultdict, deque
8
from collections.abc import Awaitable, Iterable
9
from functools import cached_property, lru_cache, partial
10
from pathlib import Path
11
12
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
                    cast)
13

14
15
import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
16
17
# yapf conflicts with isort for this block
# yapf: disable
18
from openai.types.chat import (ChatCompletionAssistantMessageParam,
19
20
                               ChatCompletionContentPartImageParam,
                               ChatCompletionContentPartInputAudioParam)
21
22
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
23
24
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
                               ChatCompletionContentPartTextParam)
25
26
from openai.types.chat import (
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
27
28
from openai.types.chat import (ChatCompletionMessageToolCallParam,
                               ChatCompletionToolMessageParam)
29
30
from openai.types.chat.chat_completion_content_part_input_audio_param import (
    InputAudio)
31
from openai.types.responses import ResponseInputImageParam
32
from openai_harmony import Message as OpenAIHarmonyMessage
33
34
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
35
# yapf: enable
36
37
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
                          ProcessorMixin)
38
# pydantic needs the TypedDict from typing_extensions
39
from typing_extensions import Required, TypeAlias, TypedDict
40

41
from vllm.config import ModelConfig
42
from vllm.logger import init_logger
43
from vllm.model_executor.models import SupportsMultiModal
44
45
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalUUIDDict)
46
from vllm.multimodal.utils import MediaConnector
47
48
49
50
# yapf: disable
from vllm.transformers_utils.chat_templates import (
    get_chat_template_fallback_path)
# yapf: enable
51
from vllm.transformers_utils.processor import cached_get_processor
52
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
53
from vllm.utils import random_uuid
54
55
56

logger = init_logger(__name__)

57
58
59
60
61
62
# Maps a modality name to the generic multi-modal placeholder token used as
# the key when model-specific placeholders are stored per modality.
MODALITY_PLACEHOLDERS_MAP = {
    "image": "<##IMAGE##>",
    "audio": "<##AUDIO##>",
    "video": "<##VIDEO##>",
}

63

64
65
66
67
68
69
70
71
72
73
74
75
# Payload of an "audio_url" content part: a single required `url` string.
class AudioURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Content part that carries an audio clip referenced by `audio_url`.

    `audio_url` and `type` are required; `uuid` may be omitted.
    """

    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""

    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
81
82


83
84
85
86
87
88
89
90
91
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
    image_embeds: Required[Union[str, dict[str, str]]]
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
92
93
94
95
96
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
97
98


99
100
101
102
103
104
105
106
107
108
109
110
# Payload of a "video_url" content part: a single required `url` string.
class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Content part that carries a video referenced by `video_url`.

    `video_url` and `type` are required; `uuid` may be omitted.
    """

    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""

    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
116
117


118
119
120
121
class PILImage(BaseModel):
    """
    A PIL.Image.Image object.
    """

    image_pil: Image.Image

    # `Image.Image` is not a pydantic-native type, so arbitrary types must be
    # allowed explicitly for this model.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    image_pil: Required[PILImage]

    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
142
143


144
145
146
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.
147

148
149
150
151
152
    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """
153

154
    image_url: Required[str]
155
156
157
158
159
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
160
161
162
163


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.
164

165
166
167
168
169
    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """
170

171
172
173
    audio_url: Required[str]


174
175
176
177
178
179
180
181
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """
182

183
    video_url: Required[str]
184
185
186
187
188
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
189
190


Julien Denize's avatar
Julien Denize committed
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    `thinking` and `type` are required; `closed` may be omitted because the
    TypedDict is declared with `total=False`.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


212
# Every shape a single content part of a chat message may take: the OpenAI
# content parts, the extended audio/video/image-embeds parts, the simplified
# "plain URL" / PIL variants, a bare string, and the thinking part.
# The member order is kept exactly as originally declared.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam,
    ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartVideoParam,
    ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPILImageParam,
    CustomChatCompletionContentSimpleImageParam,
    ChatCompletionContentPartImageEmbedsParam,
    CustomChatCompletionContentSimpleAudioParam,
    CustomChatCompletionContentSimpleVideoParam,
    str,
    CustomThinkCompletionContentParam,
]
226
227
228
229


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API.

    Only `role` is required; every other key may be omitted.
    """

    role: Required[str]
    """The role of the message's author."""

    content: Union[str, list[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""

250

251
252
253
254
255
# A chat message as accepted by this module: an OpenAI message param, a
# message with a custom role, or an OpenAI Harmony message.
ChatCompletionMessageParam = Union[
    OpenAIChatCompletionMessageParam,
    CustomChatCompletionMessageParam,
    OpenAIHarmonyMessage,
]
256
257


258
# TODO: Make fields ReadOnly once mypy supports it
259
260
261
262
class ConversationMessage(TypedDict, total=False):
    """A single conversation message; only `role` is required."""

    role: Required[str]
    """The role of the message's author."""

    content: Union[Optional[str], list[dict[str, str]]]
    """The contents of the message"""

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    name: Optional[str]
    """The name of the function to call"""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""
274
275


276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
# Passed in by user. "auto" requests detection of the format from the chat
# template; any other value is used as given.
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

# Used internally: the concrete format, once resolved (never "auto").
_ChatTemplateContentFormat = Literal["string", "openai"]


def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Return True when `node` is a `Name` node that loads `varname`."""
    if not isinstance(node, jinja2.nodes.Name):
        return False

    return node.ctx == "load" and node.name == varname


def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Return True when `node` reads `key` from the variable `varname`.

    Both spellings are recognised: subscript access with a constant key
    (`varname['key']`) and attribute access (`varname.key`).
    """
    if isinstance(node, jinja2.nodes.Getitem):
        # Subscript form: the subscript must be the constant `key`.
        base_matches = _is_var_access(node.node, varname)
        return (base_matches
                and isinstance(node.arg, jinja2.nodes.Const)
                and node.arg.value == key)

    if isinstance(node, jinja2.nodes.Getattr):
        # Attribute form.
        return _is_var_access(node.node, varname) and node.attr == key

    return False


def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: Optional[str] = None,
) -> bool:
    """Return True when `node` reads `varname` (or `varname[key]`/`.key`),
    possibly wrapped in filters, tests or slices.

    Args:
        node: The Jinja AST node to inspect.
        varname: Name of the variable being looked for.
        key: When truthy, the access must be to this key/attribute of
            `varname`; otherwise a plain access to `varname` is matched.
    """
    # Filters, tests and slices wrap an inner expression: look through them.
    if isinstance(node, jinja2.nodes.Filter):
        return node.node is not None and _is_var_or_elems_access(
            node.node, varname, key)
    if isinstance(node, jinja2.nodes.Test):
        return _is_var_or_elems_access(node.node, varname, key)

    if isinstance(node, jinja2.nodes.Getitem) and isinstance(
            node.arg, jinja2.nodes.Slice):
        return _is_var_or_elems_access(node.node, varname, key)

    # yapf: disable
    return (
        _is_attr_access(node, varname, key) if key
        else _is_var_access(node, varname)
    ) # yapf: enable


def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
    """Yield `(node, name)` for every variable that aliases `varname`.

    The first pair is `(root, varname)` itself. After that, each assignment
    whose right-hand side reads an already-known alias yields
    `(assign_node, assigned_name)`, and the assigned name becomes an alias
    to follow in turn.

    Each alias is followed at most once, so templates containing assignment
    cycles (e.g. `a = messages` together with `messages = a`) terminate.
    """
    # Global variable that is implicitly defined at the root
    yield root, varname

    # The set of assignments does not change while searching.
    assign_nodes = list(root.find_all(jinja2.nodes.Assign))

    # Iterative BFS over alias names; `seen_varnames` prevents revisiting a
    # name, which covers self-assignment as well as longer cycles.
    seen_varnames = {varname}
    related_varnames = deque([varname])
    while related_varnames:
        related_varname = related_varnames.popleft()

        for assign_ast in assign_nodes:
            lhs = assign_ast.target
            rhs = assign_ast.node

            if _is_var_or_elems_access(rhs, related_varname):
                assert isinstance(lhs, jinja2.nodes.Name)
                yield assign_ast, lhs.name

                if lhs.name not in seen_varnames:
                    seen_varnames.add(lhs.name)
                    related_varnames.append(lhs.name)


# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
    """Yield `(for_node, loop_var_name)` for each loop over the messages.

    A loop qualifies when its iterable reads `messages` or any variable
    reported as an alias of it by `_iter_nodes_assign_var_or_elems`.
    """
    aliases = [
        alias
        for _, alias in _iter_nodes_assign_var_or_elems(root, "messages")
    ]

    # Search for {%- for message in messages -%} loops
    for for_node in root.find_all(jinja2.nodes.For):
        iterable = for_node.iter

        if any(_is_var_or_elems_access(iterable, alias) for alias in aliases):
            target = for_node.target
            assert isinstance(target, jinja2.nodes.Name)
            yield for_node, target.name


def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    """Yield `(for_node, loop_var_name)` for each loop over a message's
    `content`.

    A loop qualifies when its iterable reads the `content` key/attribute of
    any per-message loop variable found by `_iter_nodes_assign_messages_item`.
    """
    message_vars = [
        name for _, name in _iter_nodes_assign_messages_item(root)
    ]

    # Search for {%- for content in message['content'] -%} loops
    for for_node in root.find_all(jinja2.nodes.For):
        iterable = for_node.iter

        if any(
                _is_var_or_elems_access(iterable, name, "content")
                for name in message_vars):
            target = for_node.target
            assert isinstance(target, jinja2.nodes.Name)
            yield for_node, target.name


def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
    """Parse `chat_template` into a Jinja AST.

    Returns:
        The parsed template, or None when compiling or parsing raises (the
        error is logged with its traceback).
    """
    try:
        compiled = hf_chat_utils._compile_jinja_template(chat_template)
        parsed = compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None

    return parsed


392
@lru_cache(maxsize=32)
def _detect_content_format(
    chat_template: str,
    *,
    default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
    """Guess the content format a chat template expects.

    Args:
        chat_template: The Jinja chat template text.
        default: Value returned when the template cannot be parsed or its
            AST cannot be analysed.

    Returns:
        "openai" when the template contains a loop over a message's
        `content`, "string" when it contains none, and `default` on error.
        Results are memoised per `(chat_template, default)` pair.
    """
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    try:
        # Only the existence of one such loop matters.
        next(_iter_nodes_assign_content_item(jinja_ast))
    except StopIteration:
        return "string"
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default
    else:
        return "openai"


413
414
415
416
417
418
def resolve_mistral_chat_template(
    chat_template: Optional[str],
    **kwargs: Any,
) -> Optional[str]:
    """Resolve the chat template for the mistral tokenizer.

    The result is always None. A warning is logged when `chat_template` is
    given, and when `add_generation_prompt` or `continue_final_message` is
    present in `kwargs`, since none of them is honoured for this tokenizer.
    """
    if chat_template is not None:
        logger.warning_once(
            "'chat_template' cannot be overridden for mistral tokenizer."
        )
    if "add_generation_prompt" in kwargs:
        logger.warning_once(
            "'add_generation_prompt' is not supported for mistral tokenizer, "
            "so it will be ignored."
        )
    if "continue_final_message" in kwargs:
        logger.warning_once(
            "'continue_final_message' is not supported for mistral tokenizer, "
            "so it will be ignored."
        )
    return None

433

434
def resolve_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    *,
    model_config: ModelConfig,
) -> Optional[str]:
    """Pick the chat template to use for a HF tokenizer.

    Sources are tried in priority order: the explicitly given template, the
    processor's template (only when `tools` is None), the tokenizer's
    template, and finally a predefined fallback file.

    Returns:
        The chat template text, or None when no source provides one.
    """
    # 1st priority: The given chat template
    if chat_template is not None:
        return chat_template

    # 2nd priority: AutoProcessor chat template, unless tool calling is enabled
    if tools is None:
        try:
            processor = cached_get_processor(
                tokenizer.name_or_path,
                processor_cls=(
                    PreTrainedTokenizer,
                    PreTrainedTokenizerFast,
                    ProcessorMixin,
                ),
                trust_remote_code=model_config.trust_remote_code,
            )
            if (
                isinstance(processor, ProcessorMixin)
                and hasattr(processor, "chat_template")
                and processor.chat_template is not None
            ):
                return processor.chat_template
        except Exception:
            # Best effort: fall through to the next source.
            logger.debug(
                "Failed to load AutoProcessor chat template for %s",
                tokenizer.name_or_path,
                exc_info=True,
            )  # noqa: E501

    # 3rd priority: AutoTokenizer chat template
    try:
        return tokenizer.get_chat_template(chat_template, tools=tools)
    except Exception:
        # Best effort: fall through to the predefined fallbacks.
        logger.debug(
            "Failed to load AutoTokenizer chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    # 4th priority: Predefined fallbacks
    path = get_chat_template_fallback_path(
        model_type=model_config.hf_config.model_type,
        tokenizer_name_or_path=model_config.tokenizer,
    )
    if path is not None:
        logger.info(
            "Loading chat template fallback for %s as there isn't one "
            "defined on HF Hub.",
            tokenizer.name_or_path,
        )
        chat_template = load_chat_template(path)
    else:
        logger.debug(
            "There is no chat template fallback for %s", tokenizer.name_or_path
        )

    return chat_template
498
499


500
501
def _resolve_chat_template_content_format(
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """Detect the content format from whichever chat template applies.

    For HF tokenizers the template is resolved through
    `resolve_hf_chat_template`; if that does not produce a string, the given
    `chat_template` is loaded as a literal instead. When no template text is
    available at all, the format is "string".
    """
    if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        hf_chat_template = resolve_hf_chat_template(
            tokenizer,
            chat_template=chat_template,
            tools=tools,
            model_config=model_config,
        )
    else:
        hf_chat_template = None

    jinja_text = (
        hf_chat_template
        if isinstance(hf_chat_template, str)
        else load_chat_template(chat_template, is_literal=True)
    )

    detected_format = (
        "string"
        if jinja_text is None
        else _detect_content_format(jinja_text, default="string")
    )

    return detected_format
530
531
532


@lru_cache
def _log_chat_template_content_format(
    chat_template: Optional[str],
    given_format: ChatTemplateContentFormatOption,
    detected_format: ChatTemplateContentFormatOption,
):
    """Log the detected content format.

    `lru_cache` means the body runs only the first time a given argument
    combination is seen (while it stays cached), so repeated calls with the
    same arguments do not log again. `chat_template` is unused in the body
    and only takes part in the cache key.
    """
    logger.info(
        "Detected the chat template content format to be '%s'. "
        "You can set `--chat-template-content-format` to override this.",
        detected_format,
    )

    if given_format != "auto" and given_format != detected_format:
        logger.warning(
            "You specified `--chat-template-content-format %s` "
            "which is different from the detected format '%s'. "
            "If our automatic detection is incorrect, please consider "
            "opening a GitHub issue so that we can improve it: "
            "https://github.com/vllm-project/vllm/issues/new/choose",
            given_format,
            detected_format,
        )

555
556
557
558
559
560

def resolve_chat_template_content_format(
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    given_format: ChatTemplateContentFormatOption,
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """Return the content format to use for chat messages.

    An explicit `given_format` (anything other than "auto") is returned
    unchanged; otherwise the format is detected from the chat template and
    the detection result is logged.
    """
    if given_format != "auto":
        return given_format

    detected_format = _resolve_chat_template_content_format(
        chat_template,
        tools,
        tokenizer,
        model_config=model_config,
    )

    _log_chat_template_content_format(
        chat_template,
        given_format=given_format,
        detected_format=detected_format,
    )

    return detected_format
581

582

583
# Modality keys accepted by the item trackers. "image_embeds" is tracked under
# its own key but validated against the "image" limit (see the tracker's
# `add`, which strips the "_embeds" suffix).
ModalityStr = Literal["image", "audio", "video", "image_embeds"]
584
585
586
587
# Type of the items stored by a multi-modal item tracker.
_T = TypeVar("_T")


class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """

    def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
        super().__init__()

        self._model_config = model_config
        self._tokenizer = tokenizer

        # Items and their optional UUIDs are appended together in `add`, so
        # the two per-modality lists stay index-aligned.
        self._items_by_modality = defaultdict[str, list[_T]](list)
        self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list)

    @property
    def model_config(self) -> ModelConfig:
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        # Imported inside the property rather than at module level.
        from vllm.model_executor.model_loader import get_model_cls

        model_cls = get_model_cls(self.model_config)
        return cast(type[SupportsMultiModal], model_cls)

    @property
    def allowed_local_media_path(self):
        return self._model_config.allowed_local_media_path

    @property
    def mm_registry(self):
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        return self.mm_registry.create_processor(self.model_config)

    def add(
        self, modality: ModalityStr, item: _T, uuid: Optional[str] = None
    ) -> Optional[str]:
        """
        Add a multi-modal item to the current prompt and returns the
        placeholder string to use, if any.

        An optional uuid can be added which serves as a unique identifier of
        the media.
        """
        # Embedding inputs are validated against the limit of the modality
        # they stand for (e.g. "image_embeds" -> "image"), while the running
        # count is taken from the list stored under `modality` itself.
        input_modality = modality.replace("_embeds", "")
        num_items = len(self._items_by_modality[modality]) + 1

        self.mm_processor.validate_num_items(input_modality, num_items)

        self._items_by_modality[modality].append(item)
        self._uuids_by_modality[modality].append(uuid)

        return self.model_cls.get_placeholder_str(modality, num_items)

    def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]:
        """Return the collected UUIDs grouped by modality.

        Returns None when no item has been added. UUIDs recorded under
        "image_embeds" are reported under the "image" key.

        Raises:
            ValueError: If both "image" and "image_embeds" were added, or
                more than one "image_embeds" item was added.
        """
        if not self._items_by_modality:
            return None
        mm_uuids = {}
        uuids_by_modality = dict(self._uuids_by_modality)
        if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed"
            )

        if "image_embeds" in uuids_by_modality:
            image_embeds_uuids = uuids_by_modality["image_embeds"]
            if len(image_embeds_uuids) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}"
                )
            mm_uuids["image"] = uuids_by_modality["image_embeds"]
        if "image" in uuids_by_modality:
            mm_uuids["image"] = uuids_by_modality["image"]  # UUIDs of images
        if "audio" in uuids_by_modality:
            mm_uuids["audio"] = uuids_by_modality["audio"]  # UUIDs of audios
        if "video" in uuids_by_modality:
            mm_uuids["video"] = uuids_by_modality["video"]  # UUIDs of videos
        return mm_uuids

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        raise NotImplementedError


676
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
    """Tracker whose items are already-loaded media objects."""

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Return the collected items grouped by modality.

        Returns None when no item has been added. A single "image_embeds"
        item is reported directly under the "image" key (not wrapped in a
        list); all other modalities map to lists.

        Raises:
            ValueError: If both "image" and "image_embeds" were added, or
                more than one "image_embeds" item was added.
        """
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        items_by_modality = dict(self._items_by_modality)
        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed"
            )

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}"
                )
            mm_inputs["image"] = image_embeds_lst[0]
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)


706
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
    """Tracker whose items are awaitables resolving to media objects."""

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Await every tracked item and return them grouped by modality.

        Returns None when no item has been added. A single "image_embeds"
        item is reported directly under the "image" key (not wrapped in a
        list); all other modalities map to lists.

        Raises:
            ValueError: If both "image" and "image_embeds" were added, or
                more than one "image_embeds" item was added.
        """
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        # The items of one modality are gathered together; modalities are
        # awaited one after another.
        items_by_modality = {
            modality: await asyncio.gather(*items)
            for modality, items in self._items_by_modality.items()
        }

        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed"
            )

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}"
                )
            mm_inputs["image"] = image_embeds_lst[0]
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    """Interface for parsing multi-modal content parts.

    Subclasses implement one `parse_*` method per supported content kind.
    This base class only keeps track of the placeholders they register.
    """

    def __init__(self) -> None:
        super().__init__()

        # stores model placeholders list with corresponding
        # general MM placeholder:
        # {
        #   "<##IMAGE##>": ["<image>", "<image>", "<image>"],
        #   "<##AUDIO##>": ["<audio>", "<audio>"]
        # }
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(
        self, modality: ModalityStr, placeholder: Optional[str]
    ):
        """Record `placeholder` under the generic token of `modality`.

        Falsy placeholders (None or "") are not stored. `modality` must be a
        key of `MODALITY_PLACEHOLDERS_MAP`.
        """
        mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
        if placeholder:
            self._placeholder_storage[mod_placeholder].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return a plain-dict shallow copy of the recorded placeholders."""
        return dict(self._placeholder_storage)

    @abstractmethod
    def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: Union[str, dict[str, str]],
        uuid: Optional[str] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Image.Image, uuid: Optional[str] = None
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: InputAudio, uuid: Optional[str] = None
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
        raise NotImplementedError

794
795
796
797
798
799
800

class MultiModalContentParser(BaseMultiModalContentParser):
    """Parser that fetches media through the connector's synchronous
    `fetch_*` methods and registers the results with a
    `MultiModalItemTracker`."""

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

        self._connector = MediaConnector(
            media_io_kwargs=self._tracker._model_config.media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
        )

    def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
        image = self._connector.fetch_image(image_url)

        placeholder = self._tracker.add("image", image, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: Union[str, dict[str, str]],
        uuid: Optional[str] = None,
    ) -> None:
        """Fetch image embeddings and register them with the tracker.

        Raises:
            TypeError: If `image_embeds` is neither a `dict` nor a `str`.
                Previously this case failed with an `UnboundLocalError`.
        """
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            placeholder = self._tracker.add("image_embeds", embeds, uuid)
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            placeholder = self._tracker.add("image_embeds", embedding, uuid)
        else:
            raise TypeError(
                "image_embeds must be a str or a dict of str, "
                f"got {type(image_embeds).__name__}"
            )

        # Tracked as "image_embeds", but the placeholder is an image one.
        self._add_placeholder("image", placeholder)

    def parse_image_pil(
        self, image_pil: Image.Image, uuid: Optional[str] = None
    ) -> None:
        placeholder = self._tracker.add("image", image_pil, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
        audio = self._connector.fetch_audio(audio_url)

        placeholder = self._tracker.add("audio", audio, uuid)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio, uuid: Optional[str] = None
    ) -> None:
        # Re-express the inline audio as a base64 data URL so it goes through
        # the same path as `parse_audio`.
        audio_data = input_audio.get("data", "")
        audio_format = input_audio.get("format", "")
        audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
        video = self._connector.fetch_video(video_url=video_url)

        placeholder = self._tracker.add("video", video, uuid)
        self._add_placeholder("video", placeholder)
856

857
858
859
860
861
862

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Content parser that hands awaitable media items to an async tracker.

    URL-based media is fetched through the connector's `*_async` methods and
    the returned object is registered with the tracker without being awaited
    here. Values that are already available (PIL images, image embeddings)
    are wrapped in an already-resolved `asyncio.Future` before registration.
    """

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        # The connector is configured from the tracker's model config and
        # its allowed local media path.
        self._connector = MediaConnector(
            media_io_kwargs=self._tracker._model_config.media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
        )

    def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
        """Register the pending fetch of `image_url` as an "image" item."""
        image_coro = self._connector.fetch_image_async(image_url)

        placeholder = self._tracker.add("image", image_coro, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: Union[str, dict[str, str]],
        uuid: Optional[str] = None,
    ) -> None:
        """Register image embeddings and record an "image" placeholder.

        Args:
            image_embeds: Either a single string, or a dict of strings. Each
                string is passed to the connector's `fetch_image_embedding`;
                for a dict the keys are kept and every value is fetched.
            uuid: Optional identifier forwarded to the tracker with the item.

        Raises:
            TypeError: If `image_embeds` is neither a `str` nor a `dict`.
        """
        # NOTE(review): this annotation mirrors the *input* type; the future
        # actually carries whatever `fetch_image_embedding` returns. Confirm
        # and tighten once that return type is known.
        future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()

        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
        elif isinstance(image_embeds, str):
            embeds = self._connector.fetch_image_embedding(image_embeds)
        else:
            # Previously an unsupported type skipped both branches, leaving
            # the future unresolved while still registering it with the
            # tracker, so anything awaiting it would never complete.
            raise TypeError(
                "image_embeds must be a str or a dict of str, "
                f"got {type(image_embeds)}"
            )

        future.set_result(embeds)

        # The item is tracked under "image_embeds", but its placeholder is
        # recorded under the "image" modality.
        placeholder = self._tracker.add("image_embeds", future, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_pil(
        self, image_pil: Image.Image, uuid: Optional[str] = None
    ) -> None:
        """Register an in-memory image via an already-resolved future."""
        future: asyncio.Future[Image.Image] = asyncio.Future()
        future.set_result(image_pil)

        placeholder = self._tracker.add("image", future, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
        """Register the pending fetch of `audio_url` as an "audio" item."""
        audio_coro = self._connector.fetch_audio_async(audio_url)

        placeholder = self._tracker.add("audio", audio_coro, uuid)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio, uuid: Optional[str] = None
    ) -> None:
        """Turn an `input_audio` payload into a data URL and parse it.

        Missing "data" / "format" entries default to empty strings.
        """
        audio_data = input_audio.get("data", "")
        audio_format = input_audio.get("format", "")
        audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
        """Register the pending fetch of `video_url` as a "video" item."""
        video = self._connector.fetch_video_async(video_url=video_url)

        placeholder = self._tracker.add("video", video, uuid)
        self._add_placeholder("video", placeholder)
924

925

926
927
928
929
930
931
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
    """Raises if the provided chat template appears invalid.

    Args:
        chat_template: `None` (nothing to validate), a `Path` to a template
            file, or a string that is either a literal template or a path.

    Raises:
        FileNotFoundError: If a `Path` is given and it does not exist.
        ValueError: If a string contains none of the characters `{`, `}` or
            a newline (so it looks like a path) but no such path exists.
        TypeError: If the value is not `None`, a `Path`, or a `str`.
    """
    if chat_template is None:
        return

    if isinstance(chat_template, Path):
        if not chat_template.exists():
            raise FileNotFoundError(
                "the supplied chat template path doesn't exist"
            )
        # An existing Path is valid. It used to fall through to the final
        # TypeError because only the "missing Path" case had its own branch.
        return

    if isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
        if (
            not any(c in chat_template for c in JINJA_CHARS)
            and not Path(chat_template).exists()
        ):
            raise ValueError(
                f"The supplied chat template string ({chat_template}) "
                f"appears path-like, but doesn't exist!"
            )
        return

    raise TypeError(f"{type(chat_template)} is not a valid chat template type")
949
950


951
def _load_chat_template(
952
953
954
955
    chat_template: Optional[Union[Path, str]],
    *,
    is_literal: bool = False,
) -> Optional[str]:
956
957
    if chat_template is None:
        return None
958
959
960

    if is_literal:
        if isinstance(chat_template, Path):
961
962
963
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )
964

965
        return chat_template
966

967
    try:
968
        with open(chat_template) as f:
969
            return f.read()
970
    except OSError as e:
971
972
973
        if isinstance(chat_template, Path):
            raise

974
975
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
976
977
978
979
980
            msg = (
                f"The supplied chat template ({chat_template}) "
                f"looks like a file path, but it failed to be "
                f"opened. Reason: {e}"
            )
981
            raise ValueError(msg) from e
982

983
984
        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
985
986
987
988
989
990
991
992
993
994
995
996
        return _load_chat_template(chat_template, is_literal=True)


# Memoized variant of `_load_chat_template`. `lru_cache` keys on the call
# arguments, so a template file that was read once is served from the cache on
# later calls with the same arguments (edits on disk are not picked up while
# the entry is cached).
_cached_load_chat_template = lru_cache(_load_chat_template)


def load_chat_template(
    chat_template: Optional[Union[Path, str]],
    *,
    is_literal: bool = False,
) -> Optional[str]:
    """Return the chat template text, using the cached loader.

    Thin wrapper that forwards both arguments unchanged to the `lru_cache`d
    `_load_chat_template`; see that function for the resolution rules and
    the exceptions it raises.
    """
    return _cached_load_chat_template(chat_template, is_literal=is_literal)
997
998


999
1000
1001
def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
1002
1003
1004
1005
1006
1007
1008
    for idx, elem in enumerate(texts):
        if elem in placeholder_storage:
            texts[idx] = placeholder_storage[elem].pop(0)

    return "\n".join(texts)


1009
# TODO: Let user specify how to insert multimodal tokens into prompt
1010
# (similar to chat template)
1011
1012
1013
1014
1015
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
) -> str:
1016
    """Combine multimodal prompts for a multimodal language model."""
1017

1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
    # flatten storage to make it looks like
    # {
    #   "<|image|>": 2,
    #   "<|audio|>": 1
    # }
    placeholder_counts = Counter(
        [v for elem in placeholder_storage.values() for v in elem]
    )

    if interleave_strings:
        text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
    else:
        text_prompt = "\n".join(texts)

    # Pass interleaved text further in case the user used image placeholders
    # himself, but forgot to disable the 'interleave_strings' flag

1035
    # Look through the text prompt to check for missing placeholders
1036
    missing_placeholders: list[str] = []
1037
1038
1039
1040
1041
    for placeholder in placeholder_counts:
        # For any existing placeholder in the text prompt, we leave it as is
        placeholder_counts[placeholder] -= text_prompt.count(placeholder)

        if placeholder_counts[placeholder] < 0:
1042
1043
1044
1045
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
1046
1047
                "when manually placing image placeholders.",
                interleave_strings,
1048
1049
            )
            logger.debug("Input prompt: %s", text_prompt)
1050
1051
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
1052
1053
                "actual multimodal data items."
            )
1054

1055
1056
1057
        missing_placeholders.extend(
            [placeholder] * placeholder_counts[placeholder]
        )
1058

1059
1060
    # NOTE: Default behaviour: we always add missing placeholders
    # at the front of the prompt, if interleave_strings=False
1061
    return "\n".join(missing_placeholders + [text_prompt])
1062
1063


1064
1065
# Part types that are not re-validated with Pydantic: `cast` only re-labels
# the incoming dict for the type checker and returns it unchanged.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)

# Part types carrying URL objects are validated through a Pydantic
# TypeAdapter.
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
_ResponsesInputImageParser = TypeAdapter(
    ResponseInputImageParam
).validate_python

# Value extracted from a single content part by the parsers above.
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]
1080

1081
# Dispatch table from a content part's "type" to the function that extracts
# its payload. Every extractor yields None when the expected key is absent.
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    "text": lambda p: _TextParser(p).get("text"),
    "thinking": lambda p: _ThinkParser(p).get("thinking"),
    "input_text": lambda p: _TextParser(p).get("text"),
    "input_image": lambda p: _ResponsesInputImageParser(p).get("image_url"),
    "image_url": lambda p: _ImageParser(p).get("image_url", {}).get("url"),
    "image_embeds": lambda p: _ImageEmbedsParser(p).get("image_embeds"),
    "image_pil": lambda p: _PILImageParser(p).get("image_pil"),
    "audio_url": lambda p: _AudioParser(p).get("audio_url", {}).get("url"),
    "input_audio": lambda p: _InputAudioParser(p).get("input_audio"),
    "refusal": lambda p: _RefusalParser(p).get("refusal"),
    "video_url": lambda p: _VideoParser(p).get("video_url", {}).get("url"),
}


def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
    """
    Parses a given multi-modal content part based on its type.

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL). For a string
          'type' that has no registered parser, a fixed marker string is
          returned instead of parsed content.

    Raises:
        ValueError: If the 'type' field is missing and no direct URL is found,
            or if the 'type' field is present but is not a string.
    """
    assert isinstance(
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
    part_type = part.get("type", None)

    if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
        # We only support 'auto', which is the default
        if part_type == "image_url":
            # 'detail' belongs to the nested 'image_url' object; looking
            # only at the top level of the part (as before) meant the
            # warning never fired for such requests. The top-level key is
            # still honoured as a fallback.
            image_url_obj = part.get("image_url")
            detail = (
                image_url_obj.get("detail")
                if isinstance(image_url_obj, dict)
                else None
            )
            if detail is None:
                detail = part.get("detail", "auto")

            if detail != "auto":
                logger.warning(
                    "'image_url.detail' is currently not supported "
                    "and will be ignored."
                )

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
    # 'type' is required field by pydantic
    if part_type is None:
        if part.get("image_url") is not None:
            image_params = cast(
                CustomChatCompletionContentSimpleImageParam, part
            )
            return "image_url", image_params.get("image_url", "")
        if part.get("audio_url") is not None:
            audio_params = cast(
                CustomChatCompletionContentSimpleAudioParam, part
            )
            return "audio_url", audio_params.get("audio_url", "")
        if part.get("input_audio") is not None:
            # NOTE(review): unlike the typed "input_audio" path, this
            # returns the whole part rather than part["input_audio"] --
            # confirm this is the intended shape for downstream parsing.
            input_audio_params = cast(dict[str, str], part)
            return "input_audio", input_audio_params
        if part.get("video_url") is not None:
            video_params = cast(
                CustomChatCompletionContentSimpleVideoParam, part
            )
            return "video_url", video_params.get("video_url", "")
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    # A string type without a registered parser: the caller decides how to
    # treat it.
    return part_type, "unknown part_type content"


1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
# Part types that `_parse_chat_message_content_part` skips (with a warning)
# when their parsed content is None.
# NOTE(review): "input_text", "input_image" and "thinking" have parsers in
# MM_PARSER_MAP but are not listed here -- confirm whether that is intended.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = (
    "text",
    "refusal",
    "image_url",
    "image_embeds",
    "image_pil",
    "audio_url",
    "input_audio",
    "video_url",
)
1186

1187

1188
1189
1190
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse all content parts of one message into a conversation message.

    Each part is parsed with a parser created from `mm_tracker`; falsy
    results are dropped. With `wrap_dicts` the parsed parts are returned
    as the message content directly. Otherwise they are treated as strings
    and merged into a single text prompt, together with any multimodal
    placeholders the parser collected.

    Returns:
        A one-element list holding the resulting message for `role`.
    """
    mm_parser = mm_tracker.create_parser()

    parsed_parts = [
        _parse_chat_message_content_part(
            part,
            mm_parser,
            wrap_dicts=wrap_dicts,
            interleave_strings=interleave_strings,
        )
        for part in parts
    ]
    content: list[_ContentPart] = [item for item in parsed_parts if item]

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role, content=content)]  # type: ignore

    texts = cast(list[str], content)
    mm_placeholder_storage = mm_parser.mm_placeholder_storage()
    text_prompt = (
        _get_full_multimodal_text_prompt(
            mm_placeholder_storage, texts, interleave_strings
        )
        if mm_placeholder_storage
        else "\n".join(texts)
    )

    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> Optional[_ContentPart]:
    """Parses a single part of a conversation.

    With `wrap_dicts`, text becomes `{"type": "text", "text": ...}` and
    media becomes `{"type": <modality>}`. Without it, text is returned as
    a plain string and media is handed to `mm_parser`; the media part then
    yields its modality placeholder when `interleave_strings` is set and
    `None` otherwise.

    Raises:
        NotImplementedError: If the part type is not a known text or media
            type.
    """
    # Plain text parts need no further processing.
    if isinstance(part, str):
        return part

    # Structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)

    # A recognised part type whose content could not be extracted is
    # skipped with a warning.
    if content is None and part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES:
        logger.warning(
            "Skipping multimodal part '%s' (type: '%s') "
            "with empty / unparsable content.",
            part,
            part_type,
        )
        return None

    if part_type in ("text", "input_text", "refusal", "thinking"):
        text = cast(str, content)
        return {"type": "text", "text": text} if wrap_dicts else text

    # Media items carry the user-provided uuid (stringified) when present,
    # and None otherwise.
    item_uuid = part.get("uuid", None)
    if item_uuid is not None:
        item_uuid = str(item_uuid)

    if part_type == "image_pil":
        mm_parser.parse_image_pil(cast(Image.Image, content), item_uuid)
        modality = "image"
    elif part_type in ("image_url", "input_image"):
        mm_parser.parse_image(cast(str, content), item_uuid)
        modality = "image"
    elif part_type == "image_embeds":
        mm_parser.parse_image_embeds(
            cast(Union[str, dict[str, str]], content), item_uuid
        )
        modality = "image"
    elif part_type == "audio_url":
        mm_parser.parse_audio(cast(str, content), item_uuid)
        modality = "audio"
    elif part_type == "input_audio":
        mm_parser.parse_input_audio(cast(InputAudio, content), item_uuid)
        modality = "audio"
    elif part_type == "video_url":
        mm_parser.parse_video(cast(str, content), item_uuid)
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")

    if wrap_dicts:
        return {"type": modality}
    if interleave_strings:
        return MODALITY_PLACEHOLDERS_MAP[modality]
    return None
1302
1303


1304
1305
1306
1307
1308
# No need to validate using Pydantic again: these helpers only `cast` the
# message to the assistant / tool message type for the type checker.
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


1309
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: _ChatTemplateContentFormat,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Convert one chat message into conversation message(s).

    The message content is normalised to a list of parts (missing content
    becomes an empty list, a plain string becomes a single text part) and
    parsed. Role-specific fields are then copied over: `tool_calls` for
    assistant messages, `tool_call_id` for tool messages, and a string
    `name` for any role.
    """
    role = message["role"]
    raw_content = message.get("content")

    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [
            ChatCompletionContentPartTextParam(type="text", text=raw_content)
        ]
    else:
        parts = raw_content

    result = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
        wrap_dicts=(content_format == "openai"),
        interleave_strings=interleave_strings,
    )

    for result_msg in result:
        if role == "assistant":
            # A None value is tolerated for compatibility with callers
            # that do not strictly follow the OpenAI spec.
            tool_calls = _AssistantParser(message).get("tool_calls")
            if tool_calls is not None:
                result_msg["tool_calls"] = list(tool_calls)
        elif role == "tool":
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                result_msg["tool_call_id"] = tool_msg["tool_call_id"]

        if "name" in message and isinstance(message["name"], str):
            result_msg["name"] = message["name"]

    return result

1354

1355
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
1356
1357
1358
1359
1360
1361
    # per the Transformers docs & maintainers, tool call arguments in
    # assistant-role messages with tool_calls need to be dicts not JSON str -
    # this is how tool-use chat templates will expect them moving forwards
    # so, for messages that have tool_calls, parse the string (which we get
    # from openAI format) to dict
    for message in messages:
1362
1363
1364
1365
1366
        if (
            message["role"] == "assistant"
            and "tool_calls" in message
            and isinstance(message["tool_calls"], list)
        ):
1367
1368
            for item in message["tool_calls"]:
                item["function"]["arguments"] = json.loads(
1369
1370
                    item["function"]["arguments"]
                )
1371
1372


1373
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    Optional[MultiModalDataDict],
    Optional[MultiModalUUIDDict],
]:
    """Parse chat messages and collect their multimodal data.

    Returns:
        A tuple of the parsed conversation, the multimodal data gathered by
        the tracker, and the multimodal uuids gathered by the tracker.
    """
    mm_tracker = MultiModalItemTracker(model_config, tokenizer)

    # Placeholders are interleaved with text only for the "string" content
    # format, and only when the multimodal config enables it.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
            )
        )

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
1403
1404


1405
def parse_chat_messages_futures(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    Awaitable[Optional[MultiModalDataDict]],
    Optional[MultiModalUUIDDict],
]:
    """Parse chat messages using the async multimodal tracker.

    Returns:
        A tuple of the parsed conversation, an awaitable for the multimodal
        data gathered by the tracker, and the multimodal uuids gathered by
        the tracker.
    """
    mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)

    # Placeholders are interleaved with text only for the "string" content
    # format, and only when the multimodal config enables it.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
            )
        )

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
1435
1436


1437
1438
def apply_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    conversation: list[ConversationMessage],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    *,
    model_config: ModelConfig,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> str:
    """Render `conversation` with the tokenizer's chat template machinery.

    The template to use is resolved first; the tokenizer's
    `apply_chat_template` is then called with it and any extra `kwargs`.

    Raises:
        ValueError: If no chat template can be resolved, or if applying the
            template raises any exception (the original is chained).
    """
    resolved_template = resolve_hf_chat_template(
        tokenizer,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
    )
    if resolved_template is None:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        )

    try:
        rendered = tokenizer.apply_chat_template(
            conversation=conversation,  # type: ignore[arg-type]
            tools=tools,  # type: ignore[arg-type]
            chat_template=resolved_template,
            tokenize=tokenize,
            **kwargs,
        )
    # Exceptions raised inside the external library are logged for further
    # investigation and surfaced as ValueError.
    except Exception as exc:
        logger.exception(
            "An error occurred in `transformers` while applying chat template"
        )
        raise ValueError(str(exc)) from exc

    return rendered
1479

1480

1481
1482
def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: list[ChatCompletionMessageParam],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    **kwargs: Any,
) -> list[int]:
    """Apply the Mistral tokenizer's chat template to `messages`.

    Raises:
        ValueError: If the tokenizer raises any exception while applying
            the template (the original exception is chained).
    """
    from mistral_common.exceptions import MistralCommonException

    # Called only for its side effects; per the original note its return
    # value is always None and is not used.
    resolve_mistral_chat_template(
        chat_template=chat_template,
        **kwargs,
    )

    try:
        token_ids = tokenizer.apply_chat_template(
            messages=messages,
            tools=tools,
            **kwargs,
        )
    # mistral-common uses assert statements to reject input that does not
    # comply with the expected format. Those are converted to ValueError so
    # the preprocessing step can catch them; they are not logged.
    except (AssertionError, MistralCommonException) as exc:
        raise ValueError(str(exc)) from exc
    # Any other exception from the external library is logged for further
    # investigation before being converted.
    except Exception as exc:
        logger.exception(
            "An error occurred in `mistral_common` while applying chat template"
        )
        raise ValueError(str(exc)) from exc

    return token_ids
1519

1520

1521
1522
1523
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
    """Return the total number of tool calls made by assistant messages.

    Only messages whose role is "assistant" are counted; a missing or None
    `tool_calls` entry contributes zero.
    """
    total = 0
    for message in conversation:
        if message["role"] != "assistant":
            continue

        calls = message.get("tool_calls")
        if calls is not None:
            # Materialise first: `tool_calls` may be any iterable.
            total += len(list(calls))
    return total


1530
1531
1532
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build a tool call id.

    For `id_type == "kimi_k2"` the id is `functions.<func_name>:<idx>`.
    Any other `id_type` yields `chatcmpl-tool-` followed by a random uuid.
    """
    if id_type != "kimi_k2":
        # by default return random
        return f"chatcmpl-tool-{random_uuid()}"

    return f"functions.{func_name}:{idx}"