chat_utils.py 56.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
from abc import ABC, abstractmethod
7
from collections import Counter, defaultdict, deque
8
from collections.abc import Awaitable, Iterable
9
from functools import cached_property, lru_cache, partial
10
from pathlib import Path
11
12
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
                    cast)
13

14
15
16
import jinja2
import jinja2.ext
import jinja2.meta
17
import jinja2.nodes
18
19
import jinja2.parser
import jinja2.sandbox
20
import transformers.utils.chat_template_utils as hf_chat_utils
21
22
# yapf conflicts with isort for this block
# yapf: disable
23
from openai.types.chat import (ChatCompletionAssistantMessageParam,
24
25
                               ChatCompletionContentPartImageParam,
                               ChatCompletionContentPartInputAudioParam)
26
27
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
28
29
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
                               ChatCompletionContentPartTextParam)
30
31
from openai.types.chat import (
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
32
33
from openai.types.chat import (ChatCompletionMessageToolCallParam,
                               ChatCompletionToolMessageParam)
34
35
from openai.types.chat.chat_completion_content_part_input_audio_param import (
    InputAudio)
36
from openai.types.responses import ResponseInputImageParam
37
from openai_harmony import Message as OpenAIHarmonyMessage
38
39
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
40
# yapf: enable
41
42
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
                          ProcessorMixin)
43
# pydantic needs the TypedDict from typing_extensions
44
from typing_extensions import Required, TypeAlias, TypedDict
45

46
from vllm.config import ModelConfig
47
from vllm.logger import init_logger
48
from vllm.model_executor.models import SupportsMultiModal
49
50
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalUUIDDict)
51
from vllm.multimodal.utils import MediaConnector
52
53
54
55
# yapf: disable
from vllm.transformers_utils.chat_templates import (
    get_chat_template_fallback_path)
# yapf: enable
56
from vllm.transformers_utils.processor import cached_get_processor
57
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
58
from vllm.utils import random_uuid, supports_kw
59
60
61

logger = init_logger(__name__)

62
63
64
65
66
67
# Maps a media modality name to the generic placeholder token that
# `BaseMultiModalContentParser._add_placeholder` uses as the storage key
# for the model-specific placeholder strings of that modality.
MODALITY_PLACEHOLDERS_MAP = {
    "image": "<##IMAGE##>",
    "audio": "<##AUDIO##>",
    "video": "<##VIDEO##>",
}

68

69
70
71
72
73
74
75
76
77
78
79
80
81
82
class AudioURL(TypedDict, total=False):
    """Location of an audio clip; `url` is the only (required) key."""

    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Content part of type "audio_url" carrying an `AudioURL` mapping."""

    # Both keys are required even though the TypedDict is `total=False`.
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


83
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
84
    image_embeds: Optional[Union[str, dict[str, str]]]
85
86
87
88
89
90
91
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
92
93
94
95
96
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
97
98


99
100
101
102
103
104
105
106
107
108
109
110
111
112
class VideoURL(TypedDict, total=False):
    """Location of a video; `url` is the only (required) key."""

    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Content part of type "video_url" carrying a `VideoURL` mapping."""

    # Both keys are required even though the TypedDict is `total=False`.
    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""


113
114
115
116
class PILImage(BaseModel):
    """
    A PIL.Image.Image object.
    """

    image_pil: Image.Image
    # `Image.Image` is not a pydantic-native type, so the model has to opt in
    # to arbitrary field types.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    image_pil: Optional[PILImage]
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
137
138


139
140
141
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.

    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """

    image_url: Optional[str]
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
155
156
157
158


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """

    audio_url: Optional[str]
167
168


169
170
171
172
173
174
175
176
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain video_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """

    video_url: Optional[str]
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
184
185


Julien Denize's avatar
Julien Denize committed
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    `thinking` and `type` are required keys; `closed` is optional.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


207
# Every content-part shape accepted inside a message's `content` list: the
# OpenAI-defined parts, the audio/video/embedding extensions and simplified
# variants declared in this module, plain strings, and "thinking" parts.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam,
    ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartVideoParam,
    ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPILImageParam,
    CustomChatCompletionContentSimpleImageParam,
    ChatCompletionContentPartImageEmbedsParam,
    CustomChatCompletionContentSimpleAudioParam,
    CustomChatCompletionContentSimpleVideoParam,
    str,
    CustomThinkCompletionContentParam,
]
221
222
223
224


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""

    role: Required[str]
    """The role of the message's author."""

    content: Union[str, list[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""

245

246
247
248
249
250
# A chat message in any accepted form: an OpenAI-defined message param, a
# message with a custom role, or an `openai_harmony` Message object.
ChatCompletionMessageParam = Union[
    OpenAIChatCompletionMessageParam,
    CustomChatCompletionMessageParam,
    OpenAIHarmonyMessage,
]
251
252


253
# TODO: Make fields ReadOnly once mypy supports it
254
255
256
257
class ConversationMessage(TypedDict, total=False):
    """A single conversation message; only `role` is a required key."""

    role: Required[str]
    """The role of the message's author."""

    content: Union[Optional[str], list[dict[str, str]]]
    """The contents of the message"""

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    name: Optional[str]
    """The name of the function to call"""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""
269
270


271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
# Content format as passed in by the user. "auto" requests detection from the
# chat template (see `resolve_chat_template_content_format`).
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

# Concrete content format used internally, i.e. without the "auto" option.
_ChatTemplateContentFormat = Literal["string", "openai"]


def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Return True iff `node` is a load of the variable named `varname`."""
    return (isinstance(node, jinja2.nodes.Name) and node.ctx == "load"
            and node.name == varname)


def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Return True iff `node` reads `varname.key` or `varname["key"]`."""
    # Attribute form: `varname.key`
    if isinstance(node, jinja2.nodes.Getattr):
        return _is_var_access(node.node, varname) and node.attr == key

    # Subscript form with a constant index: `varname["key"]`
    if isinstance(node, jinja2.nodes.Getitem):
        if not _is_var_access(node.node, varname):
            return False
        index = node.arg
        return isinstance(index, jinja2.nodes.Const) and index.value == key

    return False


def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: Optional[str] = None,
) -> bool:
    """Return whether `node` accesses `varname` or (a slice of) its elements.

    Filters, tests and slice subscripts are looked through recursively, so
    e.g. `varname | f`, `varname is t` and `varname[a:b]` all count.

    Args:
        node: The Jinja AST node to inspect.
        varname: Name of the variable being looked for.
        key: If given (and non-empty), the access must be to
            `varname[key]` / `varname.key` instead of `varname` itself.
    """
    if isinstance(node, jinja2.nodes.Filter):
        # A filter node may have no input node, hence the None check.
        return node.node is not None and _is_var_or_elems_access(
            node.node, varname, key)

    if isinstance(node, jinja2.nodes.Test):
        return _is_var_or_elems_access(node.node, varname, key)

    if isinstance(node, jinja2.nodes.Getitem) and isinstance(
            node.arg, jinja2.nodes.Slice):
        return _is_var_or_elems_access(node.node, varname, key)

    if key:
        return _is_attr_access(node, varname, key)

    return _is_var_access(node, varname)


def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
    """Yield `(node, name)` for every variable that aliases `varname`.

    The first pair is `(root, varname)` itself. After that, every `Assign`
    node under `root` whose right-hand side accesses an already known alias
    (or its elements) is yielded together with the assigned name, and that
    name becomes an alias in turn.

    Each alias is expanded at most once, so cyclic assignments such as
    `{% set a = messages %}{% set messages = a %}` terminate instead of
    being re-enqueued forever.
    """
    # Global variable that is implicitly defined at the root
    yield root, varname

    # Iterative BFS over alias names; `seen_varnames` holds every name that
    # has been enqueued so far.
    seen_varnames = {varname}
    related_varnames = deque([varname])
    while related_varnames:
        related_varname = related_varnames.popleft()

        for assign_ast in root.find_all(jinja2.nodes.Assign):
            lhs = assign_ast.target
            rhs = assign_ast.node

            if _is_var_or_elems_access(rhs, related_varname):
                assert isinstance(lhs, jinja2.nodes.Name)
                yield assign_ast, lhs.name

                # Avoid infinite looping for self-assignment and for
                # longer assignment cycles.
                if lhs.name not in seen_varnames:
                    seen_varnames.add(lhs.name)
                    related_varnames.append(lhs.name)


# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
    """Yield `(for_node, target_name)` for loops iterating over messages.

    A loop qualifies when its iterable accesses `messages` or any variable
    that aliases it.
    """
    aliases = [
        name for _, name in _iter_nodes_assign_var_or_elems(root, "messages")
    ]

    # Search for {%- for message in messages -%} loops
    for for_node in root.find_all(jinja2.nodes.For):
        iterable = for_node.iter
        if any(_is_var_or_elems_access(iterable, name) for name in aliases):
            target = for_node.target
            assert isinstance(target, jinja2.nodes.Name)
            yield for_node, target.name


def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    """Yield `(for_node, target_name)` for loops over a message's content.

    A loop qualifies when its iterable accesses the "content" entry of any
    loop variable found by `_iter_nodes_assign_messages_item`.
    """
    message_names = [
        name for _, name in _iter_nodes_assign_messages_item(root)
    ]

    # Search for {%- for content in message['content'] -%} loops
    for for_node in root.find_all(jinja2.nodes.For):
        iterable = for_node.iter
        if any(
                _is_var_or_elems_access(iterable, name, "content")
                for name in message_names):
            target = for_node.target
            assert isinstance(target, jinja2.nodes.Name)
            yield for_node, target.name


def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
    """Parse `chat_template` into a Jinja AST.

    Returns None (after logging the exception) if compiling or parsing
    fails for any reason.
    """
    try:
        compiled = hf_chat_utils._compile_jinja_template(chat_template)
        template_ast = compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None

    return template_ast


387
@lru_cache(maxsize=32)
def _detect_content_format(
    chat_template: str,
    *,
    default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
    """Detect the content format a Jinja chat template expects.

    Returns "openai" when the template contains a loop over a message's
    "content", "string" when it contains none, and `default` when the
    template cannot be compiled or its AST cannot be analysed.

    Results are memoized per `(chat_template, default)` pair.
    """
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    try:
        # Only the existence of a content loop matters, so one item suffices.
        next(_iter_nodes_assign_content_item(jinja_ast))
    except StopIteration:
        return "string"
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default

    return "openai"


408
409
410
411
412
413
def resolve_mistral_chat_template(
    chat_template: Optional[str],
    **kwargs: Any,
) -> Optional[str]:
    """Resolve the chat template for a mistral tokenizer.

    Always returns None. A warning is logged (once per message) for each of
    the options that are given but not honoured: a non-None `chat_template`,
    and the `add_generation_prompt` / `continue_final_message` keyword
    arguments.
    """
    if chat_template is not None:
        logger.warning_once(
            "'chat_template' cannot be overridden for mistral tokenizer."
        )

    if "add_generation_prompt" in kwargs:
        logger.warning_once(
            "'add_generation_prompt' is not supported for mistral tokenizer, "
            "so it will be ignored."
        )

    if "continue_final_message" in kwargs:
        logger.warning_once(
            "'continue_final_message' is not supported for mistral tokenizer, "
            "so it will be ignored."
        )

    return None

428

429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
# Keyed by `(tokenizer.name_or_path, trust_remote_code)`. The value is the
# processor's chat template, or None when no template could be obtained.
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], Optional[str]]()
"""
Used in `_try_get_processor_chat_template` to avoid calling
`cached_get_processor` again if the processor fails to be loaded.

This is needed because `lru_cache` does not cache when an exception happens.
"""


def _try_get_processor_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    model_config: ModelConfig,
) -> Optional[str]:
    """Return the chat template of the tokenizer's processor, if any.

    The outcome (including None for "no template" or "processor failed to
    load") is remembered in `_PROCESSOR_CHAT_TEMPLATES`, so the processor is
    looked up at most once per `(name_or_path, trust_remote_code)` pair.
    """
    name_or_path = tokenizer.name_or_path
    trust_remote_code = model_config.trust_remote_code
    cache_key = (name_or_path, trust_remote_code)

    try:
        return _PROCESSOR_CHAT_TEMPLATES[cache_key]
    except KeyError:
        pass

    template = None
    try:
        processor = cached_get_processor(
            name_or_path,
            processor_cls=(
                PreTrainedTokenizer,
                PreTrainedTokenizerFast,
                ProcessorMixin,
            ),
            trust_remote_code=trust_remote_code,
        )
        # Only a real processor (not a bare tokenizer) counts here.
        if isinstance(processor, ProcessorMixin):
            template = getattr(processor, "chat_template", None)
    except Exception:
        logger.debug(
            "Failed to load AutoProcessor chat template for %s",
            name_or_path,
            exc_info=True,
        )
        template = None

    _PROCESSOR_CHAT_TEMPLATES[cache_key] = template
    return template


474
def resolve_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    *,
    model_config: ModelConfig,
) -> Optional[str]:
    """Resolve the chat template to use for a HF tokenizer.

    Sources are tried in priority order: the explicitly given template, the
    AutoProcessor template (only when `tools` is None), the AutoTokenizer
    template, and finally a predefined fallback file.

    Returns:
        The resolved chat template, or None if no source provides one.
    """
    # 1st priority: The given chat template
    if chat_template is not None:
        return chat_template

    # 2nd priority: AutoProcessor chat template, unless tool calling is enabled
    if tools is None:
        chat_template = _try_get_processor_chat_template(tokenizer,
                                                         model_config)
        if chat_template is not None:
            return chat_template

    # 3rd priority: AutoTokenizer chat template
    # (`chat_template` is necessarily None at this point.)
    try:
        return tokenizer.get_chat_template(chat_template, tools=tools)
    except Exception:
        logger.debug(
            "Failed to load AutoTokenizer chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    # 4th priority: Predefined fallbacks
    path = get_chat_template_fallback_path(
        model_type=model_config.hf_config.model_type,
        tokenizer_name_or_path=model_config.tokenizer,
    )
    if path is not None:
        logger.info(
            "Loading chat template fallback for %s as there isn't one "
            "defined on HF Hub.",
            tokenizer.name_or_path,
        )
        chat_template = load_chat_template(path)
    else:
        logger.debug(
            "There is no chat template fallback for %s", tokenizer.name_or_path
        )

    return chat_template
520
521


522
523
def _resolve_chat_template_content_format(
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """Detect the content format from the effective chat template.

    For HF tokenizers the template is resolved via
    `resolve_hf_chat_template`; otherwise (or if that yields no string) the
    given `chat_template` is loaded as a literal. Returns "string" when no
    template text is available or detection falls back to its default.
    """
    if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        hf_chat_template = resolve_hf_chat_template(
            tokenizer,
            chat_template=chat_template,
            tools=tools,
            model_config=model_config,
        )
    else:
        hf_chat_template = None

    jinja_text = (
        hf_chat_template
        if isinstance(hf_chat_template, str)
        else load_chat_template(chat_template, is_literal=True)
    )

    detected_format = (
        "string"
        if jinja_text is None
        else _detect_content_format(jinja_text, default="string")
    )

    return detected_format
552
553
554


@lru_cache
def _log_chat_template_content_format(
    chat_template: Optional[str],
    given_format: ChatTemplateContentFormatOption,
    detected_format: ChatTemplateContentFormatOption,
):
    """Log the detected content format, and warn on a mismatch.

    Because of `lru_cache`, a given combination of arguments is logged only
    once while it stays in the cache; `chat_template` is not used in the
    messages and only takes part in the cache key.

    NOTE(review): the only caller visible in this file passes
    `given_format="auto"`, in which case the mismatch warning cannot fire.
    """
    logger.info(
        "Detected the chat template content format to be '%s'. "
        "You can set `--chat-template-content-format` to override this.",
        detected_format,
    )

    if given_format != "auto" and given_format != detected_format:
        logger.warning(
            "You specified `--chat-template-content-format %s` "
            "which is different from the detected format '%s'. "
            "If our automatic detection is incorrect, please consider "
            "opening a GitHub issue so that we can improve it: "
            "https://github.com/vllm-project/vllm/issues/new/choose",
            given_format,
            detected_format,
        )

577
578
579
580
581
582

def resolve_chat_template_content_format(
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    given_format: ChatTemplateContentFormatOption,
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """Return the content format to use for rendering chat messages.

    An explicit `given_format` ("string" or "openai") is returned as is,
    without inspecting the template. For "auto", the format is detected from
    the chat template and the result is logged.
    """
    if given_format != "auto":
        return given_format

    detected_format = _resolve_chat_template_content_format(
        chat_template,
        tools,
        tokenizer,
        model_config=model_config,
    )

    _log_chat_template_content_format(
        chat_template,
        given_format=given_format,
        detected_format=detected_format,
    )

    return detected_format
603

604

605
ModalityStr = Literal["image", "audio", "video", "image_embeds"]
606
607
608
609
_T = TypeVar("_T")


class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """

    def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
        super().__init__()

        self._model_config = model_config
        self._tokenizer = tokenizer

        # Items and their UUIDs are appended in lockstep by `add`, so the two
        # lists of a modality always have the same length and order.
        self._items_by_modality = defaultdict[str, list[Optional[_T]]](list)
        self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list)

    @property
    def model_config(self) -> ModelConfig:
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        # Imported lazily, inside the property rather than at module level.
        from vllm.model_executor.model_loader import get_model_cls

        model_cls = get_model_cls(self.model_config)
        return cast(type[SupportsMultiModal], model_cls)

    @property
    def allowed_local_media_path(self):
        return self._model_config.allowed_local_media_path

    @property
    def mm_registry(self):
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        return self.mm_registry.create_processor(self.model_config)

    def add(
        self,
        modality: ModalityStr,
        item: Optional[_T],
        uuid: Optional[str] = None,
    ) -> Optional[str]:
        """
        Add a multi-modal item to the current prompt and returns the
        placeholder string to use, if any.

        An optional uuid can be added which serves as a unique identifier of the
        media.
        """
        # Embedding inputs are validated against the limit of the underlying
        # modality (e.g. "image_embeds" -> "image").
        input_modality = modality.replace("_embeds", "")
        num_items = len(self._items_by_modality[modality]) + 1

        # Validate before recording, so a rejected item is never stored.
        self.mm_processor.validate_num_items(input_modality, num_items)

        self._items_by_modality[modality].append(item)
        self._uuids_by_modality[modality].append(uuid)

        return self.model_cls.get_placeholder_str(modality, num_items)

    def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]:
        """Return the tracked UUIDs per modality, or None if nothing was added.

        Raises:
            ValueError: If both raw images and image embeddings were added,
                or if more than one image-embeds item was added.
        """
        if not self._items_by_modality:
            return None
        mm_uuids = {}
        # Copy to a plain dict so membership tests cannot create entries.
        uuids_by_modality = dict(self._uuids_by_modality)
        if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed"
            )

        if "image_embeds" in uuids_by_modality:
            image_embeds_uuids = uuids_by_modality["image_embeds"]
            if len(image_embeds_uuids) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}"
                )
            # Embedding UUIDs are reported under the "image" key.
            mm_uuids["image"] = uuids_by_modality["image_embeds"]
        if "image" in uuids_by_modality:
            mm_uuids["image"] = uuids_by_modality["image"]  # UUIDs of images
        if "audio" in uuids_by_modality:
            mm_uuids["audio"] = uuids_by_modality["audio"]  # UUIDs of audios
        if "video" in uuids_by_modality:
            mm_uuids["video"] = uuids_by_modality["video"]  # UUIDs of videos
        return mm_uuids

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        """Create the content parser that feeds items into this tracker."""
        raise NotImplementedError


701
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
    """Tracker whose items are already-loaded media objects."""

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Return the tracked items per modality, or None if nothing was added.

        Raises:
            ValueError: If both raw images and image embeddings were added,
                or if more than one image-embeds item was added.
        """
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        # Copy to a plain dict so membership tests cannot create entries.
        items_by_modality = dict(self._items_by_modality)
        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed"
            )

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}"
                )
            # The single embeds item is passed as is, not wrapped in a list.
            mm_inputs["image"] = image_embeds_lst[0]
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)


731
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
    """Tracker whose items are awaitables resolving to media objects."""

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Await all tracked items and return them per modality.

        Returns None if nothing was added. Items that were added as None
        stay None in the result.

        Raises:
            ValueError: If both raw images and image embeddings were added,
                or if more than one image-embeds item was added.
        """
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        items_by_modality = {}
        for modality, items in self._items_by_modality.items():
            # `asyncio.sleep(0)` stands in for missing items: it is awaitable
            # and resolves to None, which keeps positions aligned.
            coros = [
                item if item is not None else asyncio.sleep(0)
                for item in items
            ]
            items_by_modality[modality] = await asyncio.gather(*coros)

        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed"
            )

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}"
                )
            # The single embeds item is passed as is, not wrapped in a list.
            mm_inputs["image"] = image_embeds_lst[0]
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    """Interface for parsers that turn media content parts into items.

    Subclasses implement one `parse_*` method per supported content-part
    kind; placeholders they register are collected per modality here.
    """

    def __init__(self) -> None:
        super().__init__()

        # stores model placeholders list with corresponding
        # general MM placeholder:
        # {
        #   "<##IMAGE##>": ["<image>", "<image>", "<image>"],
        #   "<##AUDIO##>": ["<audio>", "<audio>"]
        # }
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(
        self, modality: ModalityStr, placeholder: Optional[str]
    ):
        """Record `placeholder` for `modality`; empty/None ones are skipped.

        `modality` must be a key of `MODALITY_PLACEHOLDERS_MAP`.
        """
        mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
        if placeholder:
            self._placeholder_storage[mod_placeholder].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return a shallow plain-dict copy of the recorded placeholders."""
        return dict(self._placeholder_storage)

    @abstractmethod
    def parse_image(
        self, image_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: Union[str, dict[str, str], None],
        uuid: Optional[str] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(
        self, audio_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_video(
        self, video_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        raise NotImplementedError

829
830
831
832
833
834

class MultiModalContentParser(BaseMultiModalContentParser):
    """Synchronous parser: media is fetched while the part is parsed."""

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        multimodal_config = self._tracker.model_config.multimodal_config
        # `getattr` with a default also covers `multimodal_config` being None.
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
        self._connector = MediaConnector(
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
        )

    def parse_image(
        self, image_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        """Fetch the image (if a URL is given) and track it."""
        image = self._connector.fetch_image(image_url) if image_url else None

        placeholder = self._tracker.add("image", image, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: Union[str, dict[str, str], None],
        uuid: Optional[str] = None,
    ) -> None:
        """Decode image embeddings (single or keyed) and track them.

        Raises:
            TypeError: If `image_embeds` is not a str, a dict or None.
        """
        if image_embeds is None:
            embeds = None
        elif isinstance(image_embeds, str):
            embeds = self._connector.fetch_image_embedding(image_embeds)
        elif isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
        else:
            # Previously this fell through with `placeholder` unbound.
            raise TypeError(
                "image_embeds must be a str, a dict or None, "
                f"got {type(image_embeds).__name__}"
            )

        placeholder = self._tracker.add("image_embeds", embeds, uuid)
        # Embeddings share the placeholder bucket of raw images.
        self._add_placeholder("image", placeholder)

    def parse_image_pil(
        self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
    ) -> None:
        """Track an already-loaded PIL image."""
        placeholder = self._tracker.add("image", image_pil, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio(
        self, audio_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        """Fetch the audio (if a URL is given) and track it."""
        audio = self._connector.fetch_audio(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", audio, uuid)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
    ) -> None:
        """Convert inline base64 audio to a data URL and parse it as audio."""
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"
            else:
                # If a UUID is provided, audio data may be empty.
                audio_url = None
        else:
            audio_url = None

        return self.parse_audio(audio_url, uuid)

    def parse_video(
        self, video_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        """Fetch the video (if a URL is given) and track it."""
        video = (
            self._connector.fetch_video(video_url=video_url)
            if video_url
            else None
        )

        placeholder = self._tracker.add("video", video, uuid)
        self._add_placeholder("video", placeholder)
912

913
914
915
916
917
918

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Content parser that registers awaitables with an async tracker.

    URL-based media are registered as the objects returned by the
    connector's ``fetch_*_async`` methods; embeddings and PIL images are
    wrapped in ``asyncio.Future`` objects whose result is set immediately.
    """

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        mm_config = self._tracker.model_config.multimodal_config
        # `mm_config` may lack `media_io_kwargs`; fall back to None then.
        self._connector = MediaConnector(
            media_io_kwargs=getattr(mm_config, "media_io_kwargs", None),
            allowed_local_media_path=tracker.allowed_local_media_path,
        )

    def parse_image(
        self, image_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        """Register an async image fetch (or ``None`` for a falsy URL)."""
        image_coro = None
        if image_url:
            image_coro = self._connector.fetch_image_async(image_url)

        self._add_placeholder(
            "image", self._tracker.add("image", image_coro, uuid)
        )

    def parse_image_embeds(
        self,
        image_embeds: Union[str, dict[str, str], None],
        uuid: Optional[str] = None,
    ) -> None:
        """Register image embeddings wrapped in a future.

        Embeddings are fetched through the connector's
        ``fetch_image_embedding`` right here, and the outcome is stored as
        the future's result before the future is handed to the tracker.
        """
        pending: asyncio.Future[Union[str, dict[str, str], None]] = (
            asyncio.Future()
        )

        # NOTE(review): a value that is neither dict, str nor None leaves
        # the future without a result.
        if isinstance(image_embeds, dict):
            pending.set_result(
                {
                    name: self._connector.fetch_image_embedding(encoded)
                    for name, encoded in image_embeds.items()
                }
            )
        elif isinstance(image_embeds, str):
            pending.set_result(
                self._connector.fetch_image_embedding(image_embeds)
            )
        elif image_embeds is None:
            pending.set_result(None)

        self._add_placeholder(
            "image", self._tracker.add("image_embeds", pending, uuid)
        )

    def parse_image_pil(
        self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
    ) -> None:
        """Register a PIL image wrapped in an already-resolved future."""
        pending: asyncio.Future[Optional[Image.Image]] = asyncio.Future()
        pending.set_result(image_pil if image_pil else None)

        self._add_placeholder(
            "image", self._tracker.add("image", pending, uuid)
        )

    def parse_audio(
        self, audio_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        """Register an async audio fetch (or ``None`` for a falsy URL)."""
        audio_coro = None
        if audio_url:
            audio_coro = self._connector.fetch_audio_async(audio_url)

        self._add_placeholder(
            "audio", self._tracker.add("audio", audio_coro, uuid)
        )

    def parse_input_audio(
        self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
    ) -> None:
        """Convert an inline audio payload to a data URL and parse it.

        When the payload is missing or carries no data, ``None`` is
        forwarded to ``parse_audio`` instead of a URL.
        """
        audio_url = None
        if input_audio:
            encoded = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            # If a UUID is provided, audio data may be empty; in that case
            # no URL is built.
            if encoded:
                audio_url = f"data:audio/{audio_format};base64,{encoded}"

        return self.parse_audio(audio_url, uuid)

    def parse_video(
        self, video_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        """Register an async video fetch (or ``None`` for a falsy URL)."""
        video = None
        if video_url:
            video = self._connector.fetch_video_async(video_url=video_url)

        self._add_placeholder(
            "video", self._tracker.add("video", video, uuid)
        )
1011

1012

1013
1014
1015
1016
1017
1018
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
    """Raises if the provided chat template appears invalid.

    Args:
        chat_template: ``None``, a ``Path`` to a template file, or a string
            holding either a template or a path to a template file.

    Raises:
        FileNotFoundError: If a ``Path`` is given that does not exist.
        ValueError: If a string contains none of the characters ``{``,
            ``}`` or newline (so it is treated as path-like) and no such
            file exists.
        TypeError: If the value is neither ``None``, ``Path`` nor ``str``.
    """
    if chat_template is None:
        return

    if isinstance(chat_template, Path):
        if not chat_template.exists():
            raise FileNotFoundError(
                "the supplied chat template path doesn't exist"
            )
        # An existing path is valid. Returning here keeps it from falling
        # through to the "unsupported type" error below.
        return

    if isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
        has_jinja_chars = any(c in chat_template for c in JINJA_CHARS)
        if not has_jinja_chars and not Path(chat_template).exists():
            raise ValueError(
                f"The supplied chat template string ({chat_template}) "
                f"appears path-like, but doesn't exist!"
            )
        return

    raise TypeError(
        f"{type(chat_template)} is not a valid chat template type"
    )
1036
1037


1038
def _load_chat_template(
1039
1040
1041
1042
    chat_template: Optional[Union[Path, str]],
    *,
    is_literal: bool = False,
) -> Optional[str]:
1043
1044
    if chat_template is None:
        return None
1045
1046
1047

    if is_literal:
        if isinstance(chat_template, Path):
1048
1049
1050
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )
1051

1052
        return chat_template
1053

1054
    try:
1055
        with open(chat_template) as f:
1056
            return f.read()
1057
    except OSError as e:
1058
1059
1060
        if isinstance(chat_template, Path):
            raise

1061
1062
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
1063
1064
1065
1066
1067
            msg = (
                f"The supplied chat template ({chat_template}) "
                f"looks like a file path, but it failed to be "
                f"opened. Reason: {e}"
            )
1068
            raise ValueError(msg) from e
1069

1070
1071
        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
        return _load_chat_template(chat_template, is_literal=True)


# Memoised wrapper around `_load_chat_template` (default `lru_cache` size).
_cached_load_chat_template = lru_cache(_load_chat_template)


def load_chat_template(
    chat_template: Optional[Union[Path, str]],
    *,
    is_literal: bool = False,
) -> Optional[str]:
    """Load a chat template through the memoised loader.

    Takes the same arguments as ``_load_chat_template``. Repeated calls
    with equal arguments reuse the cached result while it remains in the
    cache.
    """
    template = _cached_load_chat_template(
        chat_template, is_literal=is_literal
    )
    return template
1084
1085


1086
1087
1088
def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
1089
1090
1091
1092
1093
1094
1095
    for idx, elem in enumerate(texts):
        if elem in placeholder_storage:
            texts[idx] = placeholder_storage[elem].pop(0)

    return "\n".join(texts)


1096
# TODO: Let user specify how to insert multimodal tokens into prompt
1097
# (similar to chat template)
1098
1099
1100
1101
1102
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
) -> str:
1103
    """Combine multimodal prompts for a multimodal language model."""
1104

1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
    # flatten storage to make it looks like
    # {
    #   "<|image|>": 2,
    #   "<|audio|>": 1
    # }
    placeholder_counts = Counter(
        [v for elem in placeholder_storage.values() for v in elem]
    )

    if interleave_strings:
        text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
    else:
        text_prompt = "\n".join(texts)

    # Pass interleaved text further in case the user used image placeholders
    # himself, but forgot to disable the 'interleave_strings' flag

1122
    # Look through the text prompt to check for missing placeholders
1123
    missing_placeholders: list[str] = []
1124
1125
1126
1127
1128
    for placeholder in placeholder_counts:
        # For any existing placeholder in the text prompt, we leave it as is
        placeholder_counts[placeholder] -= text_prompt.count(placeholder)

        if placeholder_counts[placeholder] < 0:
1129
1130
1131
1132
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
1133
1134
                "when manually placing image placeholders.",
                interleave_strings,
1135
1136
            )
            logger.debug("Input prompt: %s", text_prompt)
1137
1138
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
1139
1140
                "actual multimodal data items."
            )
1141

1142
1143
1144
        missing_placeholders.extend(
            [placeholder] * placeholder_counts[placeholder]
        )
1145

1146
1147
    # NOTE: Default behaviour: we always add missing placeholders
    # at the front of the prompt, if interleave_strings=False
1148
    return "\n".join(missing_placeholders + [text_prompt])
1149
1150


1151
1152
# These part types need no second Pydantic validation pass, so a static
# `cast` is enough.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)

# Parts that carry URL objects are validated at parse time.
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ResponsesInputImageParser = TypeAdapter(
    ResponseInputImageParam
).validate_python
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]

# Maps each supported part type to a function that extracts that part's
# payload (text, URL, embeddings, ...) from the raw content part.
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    "text": lambda part: _TextParser(part).get("text"),
    "thinking": lambda part: _ThinkParser(part).get("thinking"),
    "input_text": lambda part: _TextParser(part).get("text"),
    "input_image": lambda part: (
        _ResponsesInputImageParser(part).get("image_url")
    ),
    "image_url": lambda part: (
        _ImageParser(part).get("image_url", {}).get("url")
    ),
    "image_embeds": lambda part: (
        _ImageEmbedsParser(part).get("image_embeds")
    ),
    "image_pil": lambda part: _PILImageParser(part).get("image_pil"),
    "audio_url": lambda part: (
        _AudioParser(part).get("audio_url", {}).get("url")
    ),
    "input_audio": lambda part: _InputAudioParser(part).get("input_audio"),
    "refusal": lambda part: _RefusalParser(part).get("refusal"),
    "video_url": lambda part: (
        _VideoParser(part).get("video_url", {}).get("url")
    ),
}


def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
    """
    Parses a given multi-modal content part based on its type.

    Parts with a known string 'type' and no 'uuid' are dispatched through
    MM_PARSER_MAP. Parts without a 'type', or carrying a 'uuid', are matched
    on the media field they contain instead, because with a 'uuid' the media
    payload itself may be None.

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL).

    Raises:
        ValueError: If the part has no 'type' (or carries a 'uuid') and no
            recognised media field is found, or if 'type' is not a string.
    """
    assert isinstance(
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
    part_type = part.get("type", None)
    uuid = part.get("uuid", None)

    if (
        isinstance(part_type, str)
        and part_type in MM_PARSER_MAP
        and uuid is None
    ):
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
        # We only support 'auto', which is the default
        if part_type == "image_url":
            # The OpenAI schema nests 'detail' inside the 'image_url'
            # object; a top-level 'detail' is still honoured as before.
            details = [part.get("detail", "auto")]
            image_url_field = part.get("image_url")
            if isinstance(image_url_field, dict):
                details.append(image_url_field.get("detail") or "auto")
            if any(detail != "auto" for detail in details):
                logger.warning(
                    "'image_url.detail' is currently not supported "
                    "and will be ignored."
                )

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
    # 'type' is required field by pydantic
    if part_type is None or uuid is not None:
        if "image_url" in part:
            image_params = cast(
                CustomChatCompletionContentSimpleImageParam, part
            )
            image_url = image_params.get("image_url", None)
            if isinstance(image_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                image_url = image_url.get("url", None)
            return "image_url", image_url
        if "image_pil" in part:
            # "image_pil" could be None if UUID is provided.
            image_params = cast(  # type: ignore
                CustomChatCompletionContentPILImageParam, part
            )
            image_pil = image_params.get("image_pil", None)
            return "image_pil", image_pil
        if "image_embeds" in part:
            # "image_embeds" could be None if UUID is provided.
            image_params = cast(  # type: ignore
                ChatCompletionContentPartImageEmbedsParam, part
            )
            image_embeds = image_params.get("image_embeds", None)
            return "image_embeds", image_embeds
        if "audio_url" in part:
            audio_params = cast(
                CustomChatCompletionContentSimpleAudioParam, part
            )
            audio_url = audio_params.get("audio_url", None)
            if isinstance(audio_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                audio_url = audio_url.get("url", None)
            return "audio_url", audio_url
        input_audio = part.get("input_audio")
        if input_audio is not None:
            # Return the inner payload (the dict holding "data"/"format"),
            # the same shape MM_PARSER_MAP["input_audio"] yields. Returning
            # the enclosing part would hide "data" from the audio parser
            # and the audio would be dropped.
            input_audio_params = cast(dict[str, str], input_audio)
            return "input_audio", input_audio_params
        if "video_url" in part:
            video_params = cast(
                CustomChatCompletionContentSimpleVideoParam, part
            )
            video_url = video_params.get("video_url", None)
            if isinstance(video_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                video_url = video_url.get("url", None)
            return "video_url", video_url
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    # A string type that MM_PARSER_MAP does not know: hand it back with a
    # marker value and let the caller decide how to treat it.
    return part_type, "unknown part_type content"


1293
# Part types whose `None` content causes the part to be skipped.
PART_TYPES_TO_SKIP_NONE_CONTENT = ("text", "refusal")
1297

1298

1299
1300
1301
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse all content parts of one message into a single-message list.

    With ``wrap_dicts`` the parsed parts are kept as a list of dictionaries.
    Otherwise they are joined into one text prompt, with multimodal
    placeholders merged in when the parser recorded any.
    """
    mm_parser = mm_tracker.create_parser()

    parsed_parts = (
        _parse_chat_message_content_part(
            part,
            mm_parser,
            wrap_dicts=wrap_dicts,
            interleave_strings=interleave_strings,
        )
        for part in parts
    )
    # Parts that parse to a falsy value are dropped.
    content: list[_ContentPart] = [
        parsed for parsed in parsed_parts if parsed
    ]

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role, content=content)]  # type: ignore

    texts = cast(list[str], content)
    mm_placeholder_storage = mm_parser.mm_placeholder_storage()
    if not mm_placeholder_storage:
        text_prompt = "\n".join(texts)
    else:
        text_prompt = _get_full_multimodal_text_prompt(
            mm_placeholder_storage, texts, interleave_strings
        )

    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> Optional[_ContentPart]:
    """Parses a single part of a conversation.

    If wrap_dicts is True, texts and media are returned wrapped in
    dictionaries, i.e., {"type": "text", "text": ...} and
    {"type": <modality>}, respectively. Otherwise media are handed to
    mm_parser and texts are returned as plain strings; for media the
    modality placeholder is returned when interleave_strings is set and
    None otherwise.

    Raises:
        NotImplementedError: If the part type is not handled here.
    """
    # Plain strings are already text.
    if isinstance(part, str):
        return part

    # Structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)

    # Only the part types in PART_TYPES_TO_SKIP_NONE_CONTENT are skipped
    # (with a warning) when they carry no content.
    if content is None and part_type in PART_TYPES_TO_SKIP_NONE_CONTENT:
        logger.warning(
            "Skipping multimodal part '%s' (type: '%s') "
            "with empty / unparsable content.",
            part,
            part_type,
        )
        return None

    if part_type in ("text", "input_text", "refusal", "thinking"):
        str_content = cast(str, content)
        return {"type": "text", "text": str_content} if wrap_dicts else str_content

    # Media items: forward the user-provided uuid (as a string) if any.
    uuid = part.get("uuid", None)
    if uuid is not None:
        uuid = str(uuid)

    if part_type == "image_pil":
        image_content = (
            cast(Image.Image, content) if content is not None else None
        )
        mm_parser.parse_image_pil(image_content, uuid)
        modality = "image"
    elif part_type in ("image_url", "input_image"):
        mm_parser.parse_image(cast(str, content), uuid)
        modality = "image"
    elif part_type == "image_embeds":
        embeds_content = (
            cast(Union[str, dict[str, str]], content)
            if content is not None
            else None
        )
        mm_parser.parse_image_embeds(embeds_content, uuid)
        modality = "image"
    elif part_type == "audio_url":
        mm_parser.parse_audio(cast(str, content), uuid)
        modality = "audio"
    elif part_type == "input_audio":
        mm_parser.parse_input_audio(cast(InputAudio, content), uuid)
        modality = "audio"
    elif part_type == "video_url":
        mm_parser.parse_video(cast(str, content), uuid)
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")

    if wrap_dicts:
        return {"type": modality}
    if interleave_strings:
        return MODALITY_PLACEHOLDERS_MAP[modality]
    return None
1419
1420


1421
1422
1423
1424
1425
# No need to validate using Pydantic again: these are `typing.cast`
# partials, which hand the message back unchanged at runtime.
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


1426
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: _ChatTemplateContentFormat,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Convert one chat message into conversation message(s).

    The message content is normalised to a list of parts (missing content
    becomes an empty list, a string becomes one text part) and parsed.
    ``tool_calls`` (assistant), ``tool_call_id`` (tool) and a string
    ``name`` are then copied from the message onto each result.
    """
    role = message["role"]
    raw_content = message.get("content")

    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [
            ChatCompletionContentPartTextParam(type="text", text=raw_content)
        ]
    else:
        parts = raw_content

    result = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
        wrap_dicts=(content_format == "openai"),
        interleave_strings=interleave_strings,
    )

    for result_msg in result:
        if role == "assistant":
            parsed_msg = _AssistantParser(message)

            # The explicit None check keeps compatibility with callers
            # that do not strictly follow the OpenAI spec and send
            # `tool_calls: null`.
            tool_calls = parsed_msg.get("tool_calls")
            if tool_calls is not None:
                result_msg["tool_calls"] = list(tool_calls)
        elif role == "tool":
            parsed_msg = _ToolParser(message)
            if "tool_call_id" in parsed_msg:
                result_msg["tool_call_id"] = parsed_msg["tool_call_id"]

        name = message.get("name")
        if isinstance(name, str):
            result_msg["name"] = name

    return result

1471

1472
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
    """Decode tool-call arguments of assistant messages in place.

    Per the Transformers docs & maintainers, tool call arguments in
    assistant-role messages with tool_calls need to be dicts, not JSON
    strings; this is how tool-use chat templates expect them. The OpenAI
    format delivers them as JSON strings, so they are parsed here.

    For every assistant message whose ``tool_calls`` is a list:
    - missing, ``None`` or empty arguments become ``{}``;
    - JSON text (``str``/``bytes``/``bytearray``) is decoded;
    - anything else (e.g. arguments that are already a dict) is left
      untouched, so already-decoded messages do not fail.

    Raises:
        json.JSONDecodeError: If textual arguments are not valid JSON.
    """
    for message in messages:
        if message["role"] != "assistant":
            continue

        tool_calls = message.get("tool_calls")
        if not isinstance(tool_calls, list):
            continue

        for item in tool_calls:
            function = item["function"]
            arguments = function.get("arguments")
            if not arguments:
                # if arguments is None or empty string, set to {}
                function["arguments"] = {}
            elif isinstance(arguments, (str, bytes, bytearray)):
                function["arguments"] = json.loads(arguments)
1490
1491


1492
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    Optional[MultiModalDataDict],
    Optional[MultiModalUUIDDict],
]:
    """Parse chat messages into a conversation plus multimodal data.

    Returns:
        A tuple of the post-processed conversation, the multimodal data
        collected by the tracker, and the corresponding multimodal UUIDs.
    """
    mm_tracker = MultiModalItemTracker(model_config, tokenizer)
    conversation: list[ConversationMessage] = []

    for message in messages:
        # Placeholders are interleaved only for the "string" content
        # format and only when the multimodal config asks for it.
        interleave = (
            content_format == "string"
            and model_config.multimodal_config is not None
            and model_config.multimodal_config.interleave_mm_strings
        )
        conversation.extend(
            _parse_chat_message_content(
                message,
                mm_tracker,
                content_format,
                interleave_strings=interleave,
            )
        )

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
1522
1523


1524
def parse_chat_messages_futures(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    Awaitable[Optional[MultiModalDataDict]],
    Optional[MultiModalUUIDDict],
]:
    """Parse chat messages using the asynchronous item tracker.

    Returns:
        A tuple of the post-processed conversation, an awaitable for the
        multimodal data collected by the tracker, and the corresponding
        multimodal UUIDs.
    """
    mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
    conversation: list[ConversationMessage] = []

    for message in messages:
        # Placeholders are interleaved only for the "string" content
        # format and only when the multimodal config asks for it.
        interleave = (
            content_format == "string"
            and model_config.multimodal_config is not None
            and model_config.multimodal_config.interleave_mm_strings
        )
        conversation.extend(
            _parse_chat_message_content(
                message,
                mm_tracker,
                content_format,
                interleave_strings=interleave,
            )
        )

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
1554
1555


1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
    """Jinja extension that parses ``generation`` ... ``endgeneration``
    blocks."""

    tags = {"generation"}

    def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
        """Parse a ``generation`` block into a ``CallBlock`` node."""
        # Consume the opening tag token and remember where it was.
        opening_tag = next(parser.stream)
        block_body = parser.parse_statements(
            ["name:endgeneration"], drop_needle=True
        )
        node = jinja2.nodes.CallBlock(
            self.call_method("_generation_support"), [], [], block_body
        )
        return node.set_lineno(opening_tag.lineno)


def resolve_chat_template_kwargs(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    chat_template: str,
    chat_template_kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Filter ``chat_template_kwargs`` down to the accepted keys.

    A key is kept when it is an explicit keyword of
    ``tokenizer.apply_chat_template`` or an undeclared variable of the
    parsed chat template. ``chat_template`` itself is always dropped.
    """
    apply_fn = tokenizer.apply_chat_template
    fn_kw = {
        name
        for name in chat_template_kwargs
        if supports_kw(apply_fn, name, allow_var_kwargs=False)
    }

    env = jinja2.sandbox.ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
    )
    template_vars = jinja2.meta.find_undeclared_variables(
        env.parse(chat_template)
    )

    # `chat_template` is excluded because the template has already been
    # resolved at this stage.
    accept_vars = (fn_kw | template_vars) - {"chat_template"}

    resolved: dict[str, Any] = {}
    for name, value in chat_template_kwargs.items():
        if name in accept_vars:
            resolved[name] = value
    return resolved


1596
1597
def apply_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    conversation: list[ConversationMessage],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    *,
    model_config: ModelConfig,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> str:
    """Apply the resolved HF chat template to ``conversation``.

    Extra ``kwargs`` are filtered through ``resolve_chat_template_kwargs``
    before being passed to the tokenizer.

    Raises:
        ValueError: If no chat template can be resolved, or if filtering
            the kwargs or applying the template raises any exception (the
            original exception is chained).
    """
    hf_chat_template = resolve_hf_chat_template(
        tokenizer,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
    )
    if hf_chat_template is None:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        )

    try:
        filtered_kwargs = resolve_chat_template_kwargs(
            tokenizer=tokenizer,
            chat_template=hf_chat_template,
            chat_template_kwargs=kwargs,
        )
        rendered = tokenizer.apply_chat_template(
            conversation=conversation,  # type: ignore[arg-type]
            tools=tools,  # type: ignore[arg-type]
            chat_template=hf_chat_template,
            tokenize=tokenize,
            **filtered_kwargs,
        )
    # Exceptions from the external library can still surface here; log
    # them for investigation and report them as ValueError.
    except Exception as e:
        logger.exception(
            "An error occurred in `transformers` while applying chat template"
        )
        raise ValueError(str(e)) from e

    return rendered
1643

1644

1645
1646
def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: list[ChatCompletionMessageParam],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    **kwargs: Any,
) -> list[int]:
    """Apply the Mistral tokenizer's chat template to ``messages``.

    Raises:
        ValueError: If applying the template raises any exception (the
            original exception is chained).
    """
    from mistral_common.exceptions import MistralCommonException

    # Called without using its result: the original notes that the return
    # value is always None.
    resolve_mistral_chat_template(
        chat_template=chat_template,
        **kwargs,
    )

    try:
        token_ids = tokenizer.apply_chat_template(
            messages=messages,
            tools=tools,
            **kwargs,
        )
    # mistral-common stops processing of malformed input with assert
    # statements. Those are converted to ValueErrors so they can be
    # properly caught in the preprocessing_input step.
    except (AssertionError, MistralCommonException) as e:
        raise ValueError(str(e)) from e
    # Any other exception from the external library is logged for
    # investigation and reported as ValueError too.
    except Exception as e:
        logger.exception(
            "An error occurred in `mistral_common` while applying chat template"
        )
        raise ValueError(str(e)) from e

    return token_ids
1683

1684

1685
1686
1687
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
    """Return the total number of tool calls in assistant messages."""
    assistant_tool_calls = (
        msg.get("tool_calls")
        for msg in conversation
        if msg["role"] == "assistant"
    )
    return sum(
        len(list(tool_calls))
        for tool_calls in assistant_tool_calls
        if tool_calls is not None
    )


1694
1695
1696
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build a tool call id.

    For ``id_type == "kimi_k2"`` the id is ``functions.<func_name>:<idx>``;
    for any other ``id_type`` it is ``chatcmpl-tool-`` followed by a
    random uuid.
    """
    if id_type != "kimi_k2":
        # by default return random
        return f"chatcmpl-tool-{random_uuid()}"
    return f"functions.{func_name}:{idx}"