chat_utils.py 56.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
from abc import ABC, abstractmethod
7
from collections import Counter, defaultdict, deque
8
from collections.abc import Awaitable, Iterable
9
from functools import cached_property, lru_cache, partial
10
from pathlib import Path
11
12
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
                    cast)
13

14
15
16
import jinja2
import jinja2.ext
import jinja2.meta
17
import jinja2.nodes
18
19
import jinja2.parser
import jinja2.sandbox
20
import transformers.utils.chat_template_utils as hf_chat_utils
21
22
# yapf conflicts with isort for this block
# yapf: disable
23
from openai.types.chat import (ChatCompletionAssistantMessageParam,
24
25
                               ChatCompletionContentPartImageParam,
                               ChatCompletionContentPartInputAudioParam)
26
27
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
28
29
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
                               ChatCompletionContentPartTextParam)
30
31
from openai.types.chat import (
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
32
33
from openai.types.chat import (ChatCompletionMessageToolCallParam,
                               ChatCompletionToolMessageParam)
34
35
from openai.types.chat.chat_completion_content_part_input_audio_param import (
    InputAudio)
36
from openai.types.responses import ResponseInputImageParam
37
from openai_harmony import Message as OpenAIHarmonyMessage
38
39
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
40
# yapf: enable
41
42
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
                          ProcessorMixin)
43
# pydantic needs the TypedDict from typing_extensions
44
from typing_extensions import Required, TypeAlias, TypedDict
45

46
from vllm.config import ModelConfig
47
from vllm.logger import init_logger
48
from vllm.model_executor.models import SupportsMultiModal
49
50
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalUUIDDict)
51
from vllm.multimodal.utils import MediaConnector
52
53
54
55
# yapf: disable
from vllm.transformers_utils.chat_templates import (
    get_chat_template_fallback_path)
# yapf: enable
56
from vllm.transformers_utils.processor import cached_get_processor
57
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
58
from vllm.utils import random_uuid, supports_kw
59
60
61

logger = init_logger(__name__)

62
63
64
65
66
67
# Maps a modality name to the generic, model-independent placeholder token
# that stands for an item of that modality. The content parsers in this file
# use these tokens as keys when recording model-specific placeholders.
MODALITY_PLACEHOLDERS_MAP = {
    "image": "<##IMAGE##>",
    "audio": "<##AUDIO##>",
    "video": "<##VIDEO##>",
}

68

69
70
71
72
73
74
75
76
77
78
79
80
81
82
class AudioURL(TypedDict, total=False):
    """Wrapper object holding the location of an audio clip."""

    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Content part that references audio through an `audio_url` object."""

    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


83
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
84
    image_embeds: Optional[Union[str, dict[str, str]]]
85
86
87
88
89
90
91
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
92
93
94
95
96
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
97
98


99
100
101
102
103
104
105
106
107
108
109
110
111
112
class VideoURL(TypedDict, total=False):
    """Wrapper object holding the location of a video."""

    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Content part that references a video through a `video_url` object."""

    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""


113
114
115
116
class PILImage(BaseModel):
    """
    A PIL.Image.Image object.
    """

    image_pil: Image.Image
    # `Image.Image` is not a type pydantic knows how to validate natively, so
    # arbitrary types have to be allowed for this model.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    image_pil: Optional[PILImage]
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
137
138


139
140
141
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.

    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """

    image_url: Optional[str]
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
155
156
157
158


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """

    audio_url: Optional[str]
167
168


169
170
171
172
173
174
175
176
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain video_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """

    video_url: Optional[str]
    uuid: Optional[str]
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
184
185


Julien Denize's avatar
Julien Denize committed
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    `thinking` and `type` are required keys; `closed` may be omitted.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


207
# Every shape a single content part of a chat message may take: the stock
# OpenAI content parts, the audio/video/embedding parts and simplified
# "Custom*" parts declared in this module, a bare string, and a thinking part.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam,
    ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartVideoParam,
    ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPILImageParam,
    CustomChatCompletionContentSimpleImageParam,
    ChatCompletionContentPartImageEmbedsParam,
    CustomChatCompletionContentSimpleAudioParam,
    CustomChatCompletionContentSimpleVideoParam,
    str,
    CustomThinkCompletionContentParam,
]
221
222
223
224


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API.

    Only `role` is a required key; every other key may be omitted.
    """

    role: Required[str]
    """The role of the message's author."""

    content: Union[str, list[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""

245

246
247
248
249
250
# A chat message as accepted by this module: a stock OpenAI message param,
# a message with a custom role, or an OpenAI Harmony message object.
ChatCompletionMessageParam = Union[
    OpenAIChatCompletionMessageParam,
    CustomChatCompletionMessageParam,
    OpenAIHarmonyMessage,
]
251
252


253
# TODO: Make fields ReadOnly once mypy supports it
254
255
256
257
class ConversationMessage(TypedDict, total=False):
    """A chat message with a required `role` and optional content/tool keys.

    `content` is either a plain string (or `None`) or a list of
    string-to-string dictionaries.
    """

    role: Required[str]
    """The role of the message's author."""

    content: Union[Optional[str], list[dict[str, str]]]
    """The contents of the message"""

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    name: Optional[str]
    """The name of the function to call"""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""
269
270


271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
# Passed in by user; "auto" requests detection of the format from the chat
# template instead of naming one explicitly.
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

# Used internally, once "auto" has been resolved to a concrete format.
_ChatTemplateContentFormat = Literal["string", "openai"]


def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Return True if `node` is a plain read (load) of variable `varname`."""
    return (isinstance(node, jinja2.nodes.Name) and node.ctx == "load"
            and node.name == varname)


def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Return True if `node` reads `varname[key]` or `varname.key`.

    The subscript form only matches when the subscript is a constant equal
    to `key`.
    """
    # Attribute form: varname.key
    if isinstance(node, jinja2.nodes.Getattr):
        return _is_var_access(node.node, varname) and node.attr == key

    # Anything that is not a subscript cannot match either form.
    if not isinstance(node, jinja2.nodes.Getitem):
        return False

    # Subscript form: varname["key"]
    subscript = node.arg
    if not _is_var_access(node.node, varname):
        return False
    if not isinstance(subscript, jinja2.nodes.Const):
        return False
    return subscript.value == key


def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: Optional[str] = None,
) -> bool:
    """Return True if `node` reads `varname` (or its `key` item), possibly
    wrapped in filters, tests or slice subscripts.

    Args:
        node: The Jinja AST expression to inspect.
        varname: Name of the variable being looked for.
        key: If truthy, look for `varname[key]` / `varname.key` instead of
            the bare variable.
    """
    # Unwrap `expr | filter`; a filter node may have no operand.
    if isinstance(node, jinja2.nodes.Filter):
        return node.node is not None and _is_var_or_elems_access(
            node.node, varname, key)
    # Unwrap `expr is test`.
    if isinstance(node, jinja2.nodes.Test):
        return _is_var_or_elems_access(node.node, varname, key)

    # Unwrap `expr[a:b]`: a slice still yields elements of the same variable.
    if isinstance(node, jinja2.nodes.Getitem) and isinstance(
            node.arg, jinja2.nodes.Slice):
        return _is_var_or_elems_access(node.node, varname, key)

    # yapf: disable
    return (
        _is_attr_access(node, varname, key) if key
        else _is_var_access(node, varname)
    ) # yapf: enable


def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
    """Yield `(node, name)` for `varname` and every variable aliasing it.

    The first pair is `(root, varname)` itself. After that, every assignment
    whose right-hand side reads an already-known name (directly, or through
    filters/tests/slices) yields `(assign_node, assigned_name)`, and the
    assigned name is followed transitively.

    Each name is expanded at most once, so self-assignments
    (`messages = messages[1:]`) and cyclic alias chains
    (`a = messages` ... `messages = a`) both terminate.
    """
    # Global variable that is implicitly defined at the root
    yield root, varname

    # Iterative BFS over alias names
    visited = {varname}
    pending = deque([varname])
    while pending:
        current = pending.popleft()

        for assign_ast in root.find_all(jinja2.nodes.Assign):
            lhs = assign_ast.target
            rhs = assign_ast.node

            if _is_var_or_elems_access(rhs, current):
                assert isinstance(lhs, jinja2.nodes.Name)
                yield assign_ast, lhs.name

                # Never enqueue a name twice: this avoids infinite looping
                # for self-assignment as well as for longer alias cycles.
                if lhs.name not in visited:
                    visited.add(lhs.name)
                    pending.append(lhs.name)


# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
    """Yield `(for_node, loop_variable_name)` for each loop over messages.

    A loop matches when its iterable reads `messages` or any variable found
    to alias it, e.g. `{%- for message in messages -%}`.
    """
    aliases = [
        name for _, name in _iter_nodes_assign_var_or_elems(root, "messages")
    ]

    for for_node in root.find_all(jinja2.nodes.For):
        iterable = for_node.iter
        # One yield per loop, no matter how many aliases match.
        if any(_is_var_or_elems_access(iterable, name) for name in aliases):
            target = for_node.target
            assert isinstance(target, jinja2.nodes.Name)
            yield for_node, target.name


def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    """Yield `(for_node, loop_variable_name)` for each loop over a message's
    `content`, e.g. `{%- for content in message['content'] -%}`.

    The message variables considered are the loop variables of the loops
    that iterate over the messages.
    """
    message_names = [
        name for _, name in _iter_nodes_assign_messages_item(root)
    ]

    for for_node in root.find_all(jinja2.nodes.For):
        iterable = for_node.iter
        # One yield per loop, no matter how many message variables match.
        if any(
                _is_var_or_elems_access(iterable, name, "content")
                for name in message_names):
            target = for_node.target
            assert isinstance(target, jinja2.nodes.Name)
            yield for_node, target.name


def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
    """Parse `chat_template` into a Jinja AST, or return None on any error.

    The template is compiled through the HF chat-template helper first and
    then parsed with that compiled template's environment. Failures are
    logged with their traceback instead of being raised.
    """
    try:
        compiled = hf_chat_utils._compile_jinja_template(chat_template)
        template_ast = compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None

    return template_ast


387
@lru_cache(maxsize=32)
def _detect_content_format(
    chat_template: str,
    *,
    default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
    """Guess which message content format a Jinja chat template expects.

    Args:
        chat_template: The Jinja source of the chat template.
        default: Format returned when the template cannot be compiled or its
            AST cannot be analysed.

    Returns:
        "openai" if the template contains a loop over a message's `content`,
        "string" if no such loop exists, otherwise `default`.
    """
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    try:
        # Only the existence of a first match matters.
        next(_iter_nodes_assign_content_item(jinja_ast))
    except StopIteration:
        # No loop over message content: the template uses it as a string.
        return "string"
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default
    else:
        return "openai"


408
409
410
411
412
413
def resolve_mistral_chat_template(
    chat_template: Optional[str],
    **kwargs: Any,
) -> Optional[str]:
    """Warn about options the mistral tokenizer ignores; always return None.

    Args:
        chat_template: A user-supplied template; if given, a warning says it
            cannot be overridden.
        **kwargs: Extra chat-template options. The presence of
            `add_generation_prompt` or `continue_final_message` triggers a
            warning that the option is ignored.

    Returns:
        Always None.
    """
    if chat_template is not None:
        logger.warning_once(
            "'chat_template' cannot be overridden for mistral tokenizer."
        )
    if "add_generation_prompt" in kwargs:
        logger.warning_once(
            "'add_generation_prompt' is not supported for mistral tokenizer, "
            "so it will be ignored."
        )
    if "continue_final_message" in kwargs:
        logger.warning_once(
            "'continue_final_message' is not supported for mistral tokenizer, "
            "so it will be ignored."
        )
    return None

428

429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
# Keyed by (tokenizer.name_or_path, trust_remote_code); the value is the
# processor's chat template, or None when none could be obtained.
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], Optional[str]]()
"""
Used in `_try_get_processor_chat_template` to avoid calling
`cached_get_processor` again if the processor fails to be loaded.

This is needed because `lru_cache` does not cache when an exception happens.
"""


def _try_get_processor_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    model_config: ModelConfig,
) -> Optional[str]:
    """Return the chat template of the model's processor, if there is one.

    The outcome, including None for a missing template or a failed load, is
    memoised in `_PROCESSOR_CHAT_TEMPLATES` so the processor is not loaded
    again for the same tokenizer path and `trust_remote_code` setting.
    """
    cache_key = (tokenizer.name_or_path, model_config.trust_remote_code)
    try:
        return _PROCESSOR_CHAT_TEMPLATES[cache_key]
    except KeyError:
        pass

    resolved: Optional[str] = None
    try:
        processor = cached_get_processor(
            tokenizer.name_or_path,
            processor_cls=(
                PreTrainedTokenizer,
                PreTrainedTokenizerFast,
                ProcessorMixin,
            ),
            trust_remote_code=model_config.trust_remote_code,
        )
        # Only a real processor (not a bare tokenizer) can contribute one.
        if isinstance(processor, ProcessorMixin):
            resolved = getattr(processor, "chat_template", None)
    except Exception:
        logger.debug(
            "Failed to load AutoProcessor chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    _PROCESSOR_CHAT_TEMPLATES[cache_key] = resolved
    return resolved


474
def resolve_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    *,
    model_config: ModelConfig,
) -> Optional[str]:
    """Pick the chat template to use for a HF tokenizer.

    Sources are tried in priority order: the explicitly given template, the
    processor's template (only when no tools are passed), the tokenizer's
    own template, and finally a predefined fallback file.

    Args:
        tokenizer: The HF tokenizer whose template is being resolved.
        chat_template: A user-supplied template, returned as-is if not None.
        tools: Tool definitions; when not None the processor template is
            skipped and the tools are forwarded to the tokenizer lookup.
        model_config: Supplies `trust_remote_code`, the HF model type and the
            tokenizer name used for the fallback lookup.

    Returns:
        The resolved template, or None if no source provides one.
    """
    # 1st priority: The given chat template
    if chat_template is not None:
        return chat_template

    # 2nd priority: AutoProcessor chat template, unless tool calling is enabled
    if tools is None:
        chat_template = _try_get_processor_chat_template(tokenizer,
                                                         model_config)
        if chat_template is not None:
            return chat_template

    # 3rd priority: AutoTokenizer chat template
    # (`chat_template` is always None by the time this point is reached.)
    try:
        return tokenizer.get_chat_template(chat_template, tools=tools)
    except Exception:
        logger.debug(
            "Failed to load AutoTokenizer chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    # 4th priority: Predefined fallbacks
    path = get_chat_template_fallback_path(
        model_type=model_config.hf_config.model_type,
        tokenizer_name_or_path=model_config.tokenizer,
    )
    if path is not None:
        logger.info(
            "Loading chat template fallback for %s as there isn't one "
            "defined on HF Hub.",
            tokenizer.name_or_path,
        )
        chat_template = load_chat_template(path)
    else:
        logger.debug(
            "There is no chat template fallback for %s", tokenizer.name_or_path
        )

    return chat_template
520
521


522
523
def _resolve_chat_template_content_format(
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """Detect the content format of the chat template that will be used.

    For HF tokenizers the template is resolved with
    `resolve_hf_chat_template`; for any other tokenizer, or when that yields
    no string, the given `chat_template` is loaded as a literal instead.

    Returns:
        The detected format, or "string" when no template text is available.
    """
    if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        hf_chat_template = resolve_hf_chat_template(
            tokenizer,
            chat_template=chat_template,
            tools=tools,
            model_config=model_config,
        )
    else:
        hf_chat_template = None

    # NOTE(review): `load_chat_template` is defined outside the visible part
    # of this file; presumably it returns None for a None input - confirm.
    jinja_text = (
        hf_chat_template
        if isinstance(hf_chat_template, str)
        else load_chat_template(chat_template, is_literal=True)
    )

    detected_format = (
        "string"
        if jinja_text is None
        else _detect_content_format(jinja_text, default="string")
    )

    return detected_format
552
553
554


@lru_cache
def _log_chat_template_content_format(
    chat_template: Optional[str],
    given_format: ChatTemplateContentFormatOption,
    detected_format: ChatTemplateContentFormatOption,
):
    """Log the detected content format and warn on a user/detection mismatch.

    Wrapped in `lru_cache` so that an identical combination of arguments is
    logged only while its entry stays in the cache, not on every call.

    Args:
        chat_template: Not used in the messages; it takes part in the cache
            key only.
        given_format: The format requested by the user.
        detected_format: The format found by automatic detection.
    """
    logger.info(
        "Detected the chat template content format to be '%s'. "
        "You can set `--chat-template-content-format` to override this.",
        detected_format,
    )

    if given_format != "auto" and given_format != detected_format:
        logger.warning(
            "You specified `--chat-template-content-format %s` "
            "which is different from the detected format '%s'. "
            "If our automatic detection is incorrect, please consider "
            "opening a GitHub issue so that we can improve it: "
            "https://github.com/vllm-project/vllm/issues/new/choose",
            given_format,
            detected_format,
        )

577
578
579
580
581
582

def resolve_chat_template_content_format(
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    given_format: ChatTemplateContentFormatOption,
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """Resolve the user's content-format option to a concrete format.

    Args:
        chat_template: A user-supplied chat template, if any.
        tools: Tool definitions forwarded to template resolution.
        given_format: "auto" to detect the format, or an explicit format.
        tokenizer: The tokenizer whose chat template is inspected.
        model_config: Model configuration forwarded to template resolution.

    Returns:
        `given_format` unchanged when it is not "auto"; otherwise the format
        detected from the chat template.
    """
    # An explicit choice always wins; no detection or logging happens.
    if given_format != "auto":
        return given_format

    detected_format = _resolve_chat_template_content_format(
        chat_template,
        tools,
        tokenizer,
        model_config=model_config,
    )

    _log_chat_template_content_format(
        chat_template,
        given_format=given_format,
        detected_format=detected_format,
    )

    return detected_format
603

604

605
# Modality names accepted by the multi-modal item trackers below.
ModalityStr = Literal["image", "audio", "video", "image_embeds"]
606
607
608
609
# Type of the items a tracker stores (`object` for the synchronous tracker,
# `Awaitable[object]` for the asynchronous one).
_T = TypeVar("_T")


class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """

    def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
        super().__init__()

        self._model_config = model_config
        self._tokenizer = tokenizer

        # Both mappings are appended to together in `add`, so for a given
        # modality the i-th uuid belongs to the i-th item.
        self._items_by_modality = defaultdict[str, list[Optional[_T]]](list)
        self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list)

    @property
    def model_config(self) -> ModelConfig:
        """The model configuration this tracker was created with."""
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        """The model class resolved from the model config (computed once)."""
        from vllm.model_executor.model_loader import get_model_cls

        model_cls = get_model_cls(self.model_config)
        return cast(type[SupportsMultiModal], model_cls)

    @property
    def allowed_local_media_path(self):
        """`allowed_local_media_path` of the model config."""
        return self._model_config.allowed_local_media_path

    @property
    def allowed_media_domains(self):
        """`allowed_media_domains` of the model config."""
        return self._model_config.allowed_media_domains

    @property
    def mm_registry(self):
        """The global multi-modal registry."""
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        """Processor created by the registry for this model (computed once)."""
        return self.mm_registry.create_processor(self.model_config)

    def add(
        self,
        modality: ModalityStr,
        item: Optional[_T],
        uuid: Optional[str] = None,
    ) -> Optional[str]:
        """
        Add a multi-modal item to the current prompt and returns the
        placeholder string to use, if any.

        An optional uuid can be added which serves as a unique identifier of the
        media.
        """
        # Embedding inputs count against the limit of their base modality
        # ("image_embeds" -> "image").
        input_modality = modality.replace("_embeds", "")
        num_items = len(self._items_by_modality[modality]) + 1

        # Validate before recording; presumably this raises when the
        # per-prompt limit would be exceeded - confirm against the processor.
        self.mm_processor.validate_num_items(input_modality, num_items)

        self._items_by_modality[modality].append(item)
        self._uuids_by_modality[modality].append(uuid)

        return self.model_cls.get_placeholder_str(modality, num_items)

    def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]:
        """Return the recorded UUIDs grouped by modality.

        Returns None if no item has been recorded. UUIDs of image embeddings
        are reported under the "image" key.

        Raises:
            ValueError: If both raw images and image embeddings were added,
                or if more than one image-embeds item was added.
        """
        if not self._items_by_modality:
            return None
        mm_uuids = {}
        uuids_by_modality = dict(self._uuids_by_modality)
        if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed"
            )

        if "image_embeds" in uuids_by_modality:
            image_embeds_uuids = uuids_by_modality["image_embeds"]
            if len(image_embeds_uuids) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}"
                )
            mm_uuids["image"] = uuids_by_modality["image_embeds"]
        if "image" in uuids_by_modality:
            mm_uuids["image"] = uuids_by_modality["image"]  # UUIDs of images
        if "audio" in uuids_by_modality:
            mm_uuids["audio"] = uuids_by_modality["audio"]  # UUIDs of audios
        if "video" in uuids_by_modality:
            mm_uuids["video"] = uuids_by_modality["video"]  # UUIDs of videos
        return mm_uuids

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        """Create the content parser that feeds items into this tracker."""
        raise NotImplementedError


705
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
    """Tracker whose items are already-loaded objects."""

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Return the recorded items grouped by modality.

        Returns None if no item has been recorded. A single image-embeds
        item is reported under the "image" key as the item itself, not as a
        list.

        Raises:
            ValueError: If both raw images and image embeddings were added,
                or if more than one image-embeds item was added.
        """
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        items_by_modality = dict(self._items_by_modality)
        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed"
            )

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}"
                )
            mm_inputs["image"] = image_embeds_lst[0]
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Create a synchronous content parser bound to this tracker."""
        return MultiModalContentParser(self)


735
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
    """Tracker whose items are awaitables resolved in `all_mm_data`."""

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Await all recorded items and return them grouped by modality.

        Returns None if no item has been recorded. A single image-embeds
        item is reported under the "image" key as the item itself, not as a
        list.

        Raises:
            ValueError: If both raw images and image embeddings were added,
                or if more than one image-embeds item was added.
        """
        if not self._items_by_modality:
            return None
        mm_inputs = {}
        items_by_modality = {}
        for modality, items in self._items_by_modality.items():
            # `asyncio.sleep(0)` stands in for a missing item: it resolves to
            # None, so each result keeps the position of its item.
            coros = [
                item if item is not None else asyncio.sleep(0)
                for item in items
            ]
            items_by_modality[modality] = await asyncio.gather(*coros)

        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed"
            )

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}"
                )
            mm_inputs["image"] = image_embeds_lst[0]
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Create an asynchronous content parser bound to this tracker."""
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    """Interface for parsing multi-modal content parts of a chat message.

    Besides the abstract `parse_*` hooks, the base class stores the
    model-specific placeholder strings produced while parsing, grouped by
    the generic modality placeholder from `MODALITY_PLACEHOLDERS_MAP`.
    """

    def __init__(self) -> None:
        super().__init__()

        # stores model placeholders list with corresponding
        # general MM placeholder:
        # {
        #   "<##IMAGE##>": ["<image>", "<image>", "<image>"],
        #   "<##AUDIO##>": ["<audio>", "<audio>"]
        # }
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(
        self, modality: ModalityStr, placeholder: Optional[str]
    ):
        """Record `placeholder` under the generic token for `modality`.

        Falsy placeholders (None or "") are not recorded. `modality` must be
        a key of `MODALITY_PLACEHOLDERS_MAP`; any other value raises
        `KeyError`.
        """
        mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
        if placeholder:
            self._placeholder_storage[mod_placeholder].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return a shallow copy of the placeholder storage as a plain dict."""
        return dict(self._placeholder_storage)

    @abstractmethod
    def parse_image(
        self, image_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        """Handle an image given by URL, with an optional media UUID."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: Union[str, dict[str, str], None],
        uuid: Optional[str] = None,
    ) -> None:
        """Handle image embeddings given as a string or a dict of strings."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
    ) -> None:
        """Handle an image given as a PIL image object."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio(
        self, audio_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        """Handle an audio clip given by URL."""
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
    ) -> None:
        """Handle audio given as an OpenAI `InputAudio` object."""
        raise NotImplementedError

    @abstractmethod
    def parse_video(
        self, video_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        """Handle a video given by URL."""
        raise NotImplementedError

833
834
835
836
837
838

class MultiModalContentParser(BaseMultiModalContentParser):
    """Content parser that loads media synchronously.

    Each ``parse_*`` method obtains the media item through the
    ``MediaConnector`` (when a source is given), hands it to the tracker,
    and records the placeholder the tracker returns.
    """

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        multimodal_config = self._tracker.model_config.multimodal_config
        # ``getattr`` with a default tolerates a config object that does
        # not expose ``media_io_kwargs``.
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
        self._connector = MediaConnector(
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    def parse_image(
        self, image_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        """Fetch the image at ``image_url`` (if any) and register it."""
        image = self._connector.fetch_image(image_url) if image_url else None

        placeholder = self._tracker.add("image", image, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: Union[str, dict[str, str], None],
        uuid: Optional[str] = None,
    ) -> None:
        """Load image embeddings and register them as ``image_embeds``.

        Args:
            image_embeds: A single encoded embedding, a dict of encoded
                embeddings keyed by name, or ``None``.
            uuid: Optional identifier forwarded to the tracker.

        Raises:
            TypeError: If ``image_embeds`` is none of the supported types.
                Previously this surfaced as an ``UnboundLocalError``.
        """
        fetch = self._connector.fetch_image_embedding

        if image_embeds is None:
            embeds = None
        elif isinstance(image_embeds, str):
            embeds = fetch(image_embeds)
        elif isinstance(image_embeds, dict):
            embeds = {k: fetch(v) for k, v in image_embeds.items()}
        else:
            raise TypeError(
                f"Unsupported image_embeds type: {type(image_embeds)}"
            )

        placeholder = self._tracker.add("image_embeds", embeds, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_pil(
        self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
    ) -> None:
        """Register an already-loaded PIL image (or ``None``)."""
        placeholder = self._tracker.add("image", image_pil, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio(
        self, audio_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        """Fetch the audio at ``audio_url`` (if any) and register it."""
        audio = self._connector.fetch_audio(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", audio, uuid)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
    ) -> None:
        """Convert an ``InputAudio`` payload to a data URL and parse it.

        When the payload is missing or carries no data (which may be the
        case if a UUID is provided), ``parse_audio`` receives ``None``.
        """
        audio_url = None
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    def parse_video(
        self, video_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        """Fetch the video at ``video_url`` (if any) and register it."""
        video = (
            self._connector.fetch_video(video_url=video_url)
            if video_url
            else None
        )

        placeholder = self._tracker.add("video", video, uuid)
        self._add_placeholder("video", placeholder)
917

918
919
920
921
922
923

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Content parser that registers awaitables instead of loaded media.

    URL-based parts are registered as the objects returned by the
    connector's ``fetch_*_async`` methods. Values that are already
    available are wrapped in an ``asyncio.Future`` whose result is set
    immediately.
    """

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        multimodal_config = self._tracker.model_config.multimodal_config
        # ``getattr`` with a default tolerates a config object that does
        # not expose ``media_io_kwargs``.
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
        self._connector = MediaConnector(
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    def parse_image(
        self, image_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        """Register the async fetch of ``image_url`` (or ``None``)."""
        image_coro = (
            self._connector.fetch_image_async(image_url) if image_url else None
        )

        placeholder = self._tracker.add("image", image_coro, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: Union[str, dict[str, str], None],
        uuid: Optional[str] = None,
    ) -> None:
        """Load image embeddings and register them as a resolved future.

        Args:
            image_embeds: A single encoded embedding, a dict of encoded
                embeddings keyed by name, or ``None``.
            uuid: Optional identifier forwarded to the tracker.

        Raises:
            TypeError: If ``image_embeds`` is none of the supported types.
                Previously such a value registered a future whose result
                was never set, so awaiting it would never complete.
        """
        fetch = self._connector.fetch_image_embedding

        if image_embeds is None:
            embeds = None
        elif isinstance(image_embeds, str):
            embeds = fetch(image_embeds)
        elif isinstance(image_embeds, dict):
            embeds = {k: fetch(v) for k, v in image_embeds.items()}
        else:
            raise TypeError(
                f"Unsupported image_embeds type: {type(image_embeds)}"
            )

        future: asyncio.Future[Union[str, dict[str, str], None]] = (
            asyncio.Future()
        )
        future.set_result(embeds)

        placeholder = self._tracker.add("image_embeds", future, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_pil(
        self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
    ) -> None:
        """Register an already-loaded PIL image as a resolved future."""
        future: asyncio.Future[Optional[Image.Image]] = asyncio.Future()
        # A falsy value is normalised to ``None``.
        future.set_result(image_pil if image_pil else None)

        placeholder = self._tracker.add("image", future, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio(
        self, audio_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        """Register the async fetch of ``audio_url`` (or ``None``)."""
        audio_coro = (
            self._connector.fetch_audio_async(audio_url) if audio_url else None
        )

        placeholder = self._tracker.add("audio", audio_coro, uuid)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
    ) -> None:
        """Convert an ``InputAudio`` payload to a data URL and parse it.

        When the payload is missing or carries no data (which may be the
        case if a UUID is provided), ``parse_audio`` receives ``None``.
        """
        audio_url = None
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    def parse_video(
        self, video_url: Optional[str], uuid: Optional[str] = None
    ) -> None:
        """Register the async fetch of ``video_url`` (or ``None``)."""
        video = (
            self._connector.fetch_video_async(video_url=video_url)
            if video_url
            else None
        )

        placeholder = self._tracker.add("video", video, uuid)
        self._add_placeholder("video", placeholder)
1017

1018

1019
1020
1021
1022
1023
1024
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
    """Raise if the provided chat template appears invalid.

    Args:
        chat_template: ``None``, a ``Path`` to a template file, or a string
            holding either a template or a path to one.

    Raises:
        FileNotFoundError: If a ``Path`` is given and it does not exist.
        ValueError: If a string contains none of ``{``, ``}`` or a newline
            (so it looks like a path) and no such path exists.
        TypeError: If the value is neither ``None``, ``Path`` nor ``str``.
    """
    if chat_template is None:
        return

    if isinstance(chat_template, Path):
        if not chat_template.exists():
            raise FileNotFoundError(
                "the supplied chat template path doesn't exist"
            )
        # An existing path is valid; it must not fall through to the
        # "not a valid chat template type" error below.
        return

    if not isinstance(chat_template, str):
        raise TypeError(
            f"{type(chat_template)} is not a valid chat template type"
        )

    JINJA_CHARS = "{}\n"
    looks_like_template = any(c in chat_template for c in JINJA_CHARS)
    if not looks_like_template and not Path(chat_template).exists():
        raise ValueError(
            f"The supplied chat template string ({chat_template}) "
            f"appears path-like, but doesn't exist!"
        )
1042
1043


1044
def _load_chat_template(
1045
1046
1047
1048
    chat_template: Optional[Union[Path, str]],
    *,
    is_literal: bool = False,
) -> Optional[str]:
1049
1050
    if chat_template is None:
        return None
1051
1052
1053

    if is_literal:
        if isinstance(chat_template, Path):
1054
1055
1056
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )
1057

1058
        return chat_template
1059

1060
    try:
1061
        with open(chat_template) as f:
1062
            return f.read()
1063
    except OSError as e:
1064
1065
1066
        if isinstance(chat_template, Path):
            raise

1067
1068
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
1069
1070
1071
1072
1073
            msg = (
                f"The supplied chat template ({chat_template}) "
                f"looks like a file path, but it failed to be "
                f"opened. Reason: {e}"
            )
1074
            raise ValueError(msg) from e
1075

1076
1077
        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
        return _load_chat_template(chat_template, is_literal=True)


# Memoised loader; ``maxsize=128`` is the ``lru_cache`` default spelled out.
_cached_load_chat_template = lru_cache(maxsize=128)(_load_chat_template)


def load_chat_template(
    chat_template: Optional[Union[Path, str]],
    *,
    is_literal: bool = False,
) -> Optional[str]:
    """Cached front-end for ``_load_chat_template``.

    Results are memoised per ``(chat_template, is_literal)`` pair, so a
    template file is read once per distinct argument combination while the
    entry stays in the cache.
    """
    return _cached_load_chat_template(chat_template, is_literal=is_literal)
1090
1091


1092
1093
1094
def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
1095
1096
1097
1098
1099
1100
1101
    for idx, elem in enumerate(texts):
        if elem in placeholder_storage:
            texts[idx] = placeholder_storage[elem].pop(0)

    return "\n".join(texts)


1102
# TODO: Let user specify how to insert multimodal tokens into prompt
1103
# (similar to chat template)
1104
1105
1106
1107
1108
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
) -> str:
1109
    """Combine multimodal prompts for a multimodal language model."""
1110

1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
    # flatten storage to make it looks like
    # {
    #   "<|image|>": 2,
    #   "<|audio|>": 1
    # }
    placeholder_counts = Counter(
        [v for elem in placeholder_storage.values() for v in elem]
    )

    if interleave_strings:
        text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
    else:
        text_prompt = "\n".join(texts)

    # Pass interleaved text further in case the user used image placeholders
    # himself, but forgot to disable the 'interleave_strings' flag

1128
    # Look through the text prompt to check for missing placeholders
1129
    missing_placeholders: list[str] = []
1130
1131
1132
1133
1134
    for placeholder in placeholder_counts:
        # For any existing placeholder in the text prompt, we leave it as is
        placeholder_counts[placeholder] -= text_prompt.count(placeholder)

        if placeholder_counts[placeholder] < 0:
1135
1136
1137
1138
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
1139
1140
                "when manually placing image placeholders.",
                interleave_strings,
1141
1142
            )
            logger.debug("Input prompt: %s", text_prompt)
1143
1144
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
1145
1146
                "actual multimodal data items."
            )
1147

1148
1149
1150
        missing_placeholders.extend(
            [placeholder] * placeholder_counts[placeholder]
        )
1151

1152
1153
    # NOTE: Default behaviour: we always add missing placeholders
    # at the front of the prompt, if interleave_strings=False
1154
    return "\n".join(missing_placeholders + [text_prompt])
1155
1156


1157
1158
# No need to validate using Pydantic again: ``cast`` only assigns the
# static type.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)

# Need to validate url objects
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ResponsesInputImageParser = TypeAdapter(
    ResponseInputImageParam
).validate_python

_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]

# Maps a part's ``type`` to the function that extracts its payload.
# A missing field yields ``None``.
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    "text": lambda p: _TextParser(p).get("text"),
    "thinking": lambda p: _ThinkParser(p).get("thinking"),
    "input_text": lambda p: _TextParser(p).get("text"),
    "input_image": lambda p: _ResponsesInputImageParser(p).get("image_url"),
    "image_url": lambda p: _ImageParser(p).get("image_url", {}).get("url"),
    "image_embeds": lambda p: _ImageEmbedsParser(p).get("image_embeds"),
    "image_pil": lambda p: _PILImageParser(p).get("image_pil"),
    "audio_url": lambda p: _AudioParser(p).get("audio_url", {}).get("url"),
    "input_audio": lambda p: _InputAudioParser(p).get("input_audio"),
    "refusal": lambda p: _RefusalParser(p).get("refusal"),
    "video_url": lambda p: _VideoParser(p).get("video_url", {}).get("url"),
}


def _parse_chat_message_content_mm_part(
1206
1207
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
1208
    """
1209
    Parses a given multi-modal content part based on its type.
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL).

    Raises:
        ValueError: If the 'type' field is missing and no direct URL is found.
    """
    assert isinstance(
1223
1224
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
1225
    part_type = part.get("type", None)
1226
    uuid = part.get("uuid", None)
1227

1228
    if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None: # noqa: E501
1229
1230
1231
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
1232
1233
        # We only support 'auto', which is the default
        if part_type == "image_url" and part.get("detail", "auto") != "auto":
1234
1235
1236
1237
            logger.warning(
                "'image_url.detail' is currently not supported "
                "and will be ignored."
            )
1238
1239
1240
1241

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
1242
    # 'type' is required field by pydantic
1243
1244
    if part_type is None or uuid is not None:
        if "image_url" in part:
1245
1246
1247
            image_params = cast(
                CustomChatCompletionContentSimpleImageParam, part
            )
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
            image_url = image_params.get("image_url", None)
            if isinstance(image_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                image_url = image_url.get("url", None)
            return "image_url", image_url
        if "image_pil" in part:
            # "image_pil" could be None if UUID is provided.
            image_params = cast( # type: ignore 
                CustomChatCompletionContentPILImageParam, part
            )
            image_pil = image_params.get("image_pil", None)
            return "image_pil", image_pil
        if "image_embeds" in part:
            # "image_embeds" could be None if UUID is provided.
            image_params = cast( # type: ignore 
                ChatCompletionContentPartImageEmbedsParam, part
            )
            image_embeds = image_params.get("image_embeds", None)
            return "image_embeds", image_embeds
        if "audio_url" in part:
1269
1270
1271
            audio_params = cast(
                CustomChatCompletionContentSimpleAudioParam, part
            )
1272
1273
1274
1275
1276
1277
            audio_url = audio_params.get("audio_url", None)
            if isinstance(audio_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                audio_url = audio_url.get("url", None)
            return "audio_url", audio_url
1278
        if part.get("input_audio") is not None:
1279
            input_audio_params = cast(dict[str, str], part)
1280
            return "input_audio", input_audio_params
1281
        if "video_url" in part:
1282
1283
1284
            video_params = cast(
                CustomChatCompletionContentSimpleVideoParam, part
            )
1285
1286
1287
1288
1289
1290
            video_url = video_params.get("video_url", None)
            if isinstance(video_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                video_url = video_url.get("url", None)
            return "video_url", video_url
1291
1292
1293
1294
1295
1296
1297
1298
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    return part_type, "unknown part_type content"


1299
# Part types that are dropped (with a warning) when their content is None.
PART_TYPES_TO_SKIP_NONE_CONTENT = ("text", "refusal")
1303

1304

1305
1306
1307
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse all content parts of one message into a conversation message.

    Args:
        role: Role assigned to the resulting message.
        parts: Content parts of the message.
        mm_tracker: Tracker used to create the multimodal parser.
        wrap_dicts: If true, the parsed parts are kept as a list of dicts.
            Otherwise they are merged into a single prompt string.
        interleave_strings: Forwarded to the per-part parser and to the
            prompt assembly.

    Returns:
        A single-element list holding the resulting message.
    """
    mm_parser = mm_tracker.create_parser()

    # Falsy results (``None`` or an empty string) are dropped.
    content: list[_ContentPart] = [
        parsed
        for part in parts
        if (
            parsed := _parse_chat_message_content_part(
                part,
                mm_parser,
                wrap_dicts=wrap_dicts,
                interleave_strings=interleave_strings,
            )
        )
    ]

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role, content=content)]  # type: ignore

    texts = cast(list[str], content)
    mm_placeholder_storage = mm_parser.mm_placeholder_storage()
    if not mm_placeholder_storage:
        text_prompt = "\n".join(texts)
    else:
        text_prompt = _get_full_multimodal_text_prompt(
            mm_placeholder_storage, texts, interleave_strings
        )

    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
1343
1344
1345
1346
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
1347
    interleave_strings: bool,
1348
) -> Optional[_ContentPart]:
1349
1350
1351
1352
1353
1354
1355
1356
    """Parses a single part of a conversation. If wrap_dicts is True,
    structured dictionary pieces for texts and images will be
    wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
    {"type": "image"}, respectively. Otherwise multimodal data will be
    handled by mm_parser, and texts will be returned as strings to be joined
    with multimodal placeholders.
    """
    if isinstance(part, str):  # Handle plain text parts
1357
        return part
1358
1359
    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)
1360
    # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
1361
    # content is None, log a warning and skip
1362
    if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None:
1363
        logger.warning(
1364
            "Skipping multimodal part '%s' (type: '%s') "
1365
1366
1367
1368
            "with empty / unparsable content.",
            part,
            part_type,
        )
1369
1370
        return None

Julien Denize's avatar
Julien Denize committed
1371
    if part_type in ("text", "input_text", "refusal", "thinking"):
1372
1373
        str_content = cast(str, content)
        if wrap_dicts:
1374
            return {"type": "text", "text": str_content}
1375
1376
        else:
            return str_content
1377

1378
1379
1380
1381
1382
1383
    # For media items, if a user has provided one, use it. Otherwise, insert
    # a placeholder empty uuid.
    uuid = part.get("uuid", None)
    if uuid is not None:
        uuid = str(uuid)

1384
    modality = None
1385
    if part_type == "image_pil":
1386
1387
1388
1389
        if content is not None:
            image_content = cast(Image.Image, content)
        else:
            image_content = None
1390
        mm_parser.parse_image_pil(image_content, uuid)
1391
        modality = "image"
1392
    elif part_type in ("image_url", "input_image"):
1393
        str_content = cast(str, content)
1394
        mm_parser.parse_image(str_content, uuid)
1395
1396
        modality = "image"
    elif part_type == "image_embeds":
1397
1398
1399
1400
        if content is not None:
            content = cast(Union[str, dict[str, str]], content)
        else:
            content = None
1401
        mm_parser.parse_image_embeds(content, uuid)
1402
1403
        modality = "image"
    elif part_type == "audio_url":
1404
        str_content = cast(str, content)
1405
        mm_parser.parse_audio(str_content, uuid)
1406
1407
        modality = "audio"
    elif part_type == "input_audio":
1408
        dict_content = cast(InputAudio, content)
1409
        mm_parser.parse_input_audio(dict_content, uuid)
1410
1411
        modality = "audio"
    elif part_type == "video_url":
1412
        str_content = cast(str, content)
1413
        mm_parser.parse_video(str_content, uuid)
1414
1415
1416
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")
1417

1418
1419
1420
1421
1422
1423
    return (
        {"type": modality}
        if wrap_dicts
        else (
            MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
        )
1424
    )
1425
1426


1427
1428
1429
1430
1431
# No need to validate using Pydantic again: ``cast`` only narrows the static
# type and hands the message back unchanged at runtime.
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


1432
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: _ChatTemplateContentFormat,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Convert one chat message into conversation messages.

    The content is normalised to a list of parts (missing content becomes
    an empty list, a plain string becomes one text part) and parsed. The
    assistant ``tool_calls``, the tool ``tool_call_id`` and a string
    ``name`` are then copied onto every resulting message.
    """
    role = message["role"]

    raw_content = message.get("content")
    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [
            ChatCompletionContentPartTextParam(type="text", text=raw_content)
        ]
    else:
        parts = raw_content

    result = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
        wrap_dicts=(content_format == "openai"),
        interleave_strings=interleave_strings,
    )

    for result_msg in result:
        if role == "assistant":
            # The 'tool_calls' is not None check ensures compatibility.
            # It's needed only if downstream code doesn't strictly
            # follow the OpenAI spec.
            tool_calls = _AssistantParser(message).get("tool_calls")
            if tool_calls is not None:
                result_msg["tool_calls"] = list(tool_calls)
        elif role == "tool":
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                result_msg["tool_call_id"] = tool_msg["tool_call_id"]

        name = message.get("name")
        if isinstance(name, str):
            result_msg["name"] = name

    return result

1477

1478
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
1479
1480
1481
1482
1483
1484
    # per the Transformers docs & maintainers, tool call arguments in
    # assistant-role messages with tool_calls need to be dicts not JSON str -
    # this is how tool-use chat templates will expect them moving forwards
    # so, for messages that have tool_calls, parse the string (which we get
    # from openAI format) to dict
    for message in messages:
1485
1486
1487
1488
1489
        if (
            message["role"] == "assistant"
            and "tool_calls" in message
            and isinstance(message["tool_calls"], list)
        ):
1490
            for item in message["tool_calls"]:
1491
1492
1493
1494
1495
                # if arguments is None or empty string, set to {}
                if content := item["function"].get("arguments"):
                    item["function"]["arguments"] = json.loads(content)
                else:
                    item["function"]["arguments"] = {}
1496
1497


1498
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    Optional[MultiModalDataDict],
    Optional[MultiModalUUIDDict],
]:
    """Parse chat messages with a synchronous multimodal tracker.

    Returns:
        The parsed conversation, followed by the multimodal data and the
        multimodal UUIDs reported by the tracker.
    """
    mm_tracker = MultiModalItemTracker(model_config, tokenizer)
    conversation: list[ConversationMessage] = []

    for msg in messages:
        # Markers are interleaved only for the "string" content format
        # and when the multimodal config asks for it.
        interleave_strings = (
            content_format == "string"
            and model_config.multimodal_config is not None
            and model_config.multimodal_config.interleave_mm_strings
        )
        conversation += _parse_chat_message_content(
            msg,
            mm_tracker,
            content_format,
            interleave_strings=interleave_strings,
        )

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
1528
1529


1530
def parse_chat_messages_futures(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    content_format: _ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    Awaitable[Optional[MultiModalDataDict]],
    Optional[MultiModalUUIDDict],
]:
    """Parse chat messages with the asynchronous multimodal tracker.

    Returns:
        The parsed conversation, followed by the tracker's multimodal data
        (annotated as an awaitable) and its multimodal UUIDs.
    """
    mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
    conversation: list[ConversationMessage] = []

    for msg in messages:
        # Markers are interleaved only for the "string" content format
        # and when the multimodal config asks for it.
        interleave_strings = (
            content_format == "string"
            and model_config.multimodal_config is not None
            and model_config.multimodal_config.interleave_mm_strings
        )
        conversation += _parse_chat_message_content(
            msg,
            mm_tracker,
            content_format,
            interleave_strings=interleave_strings,
        )

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
1560
1561


1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
    """Jinja extension that lets templates using the ``generation`` tag be
    parsed."""

    tags = {"generation"}

    def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
        """Parse a ``generation`` ... ``endgeneration`` block into a call
        block."""
        # The first token is the tag itself; keep its line number.
        tag_lineno = next(parser.stream).lineno
        block_body = parser.parse_statements(
            ["name:endgeneration"], drop_needle=True
        )
        block = jinja2.nodes.CallBlock(
            self.call_method("_generation_support"), [], [], block_body
        )
        return block.set_lineno(tag_lineno)


1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
def _resolve_chat_template_kwargs(
    chat_template: str,
):
    """Return the undeclared variable names found in ``chat_template``.

    The template is only parsed, not rendered.
    """
    sandbox = jinja2.sandbox.ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
    )
    return jinja2.meta.find_undeclared_variables(sandbox.parse(chat_template))


# Memoised per template string; ``maxsize=128`` is the ``lru_cache`` default
# spelled out.
_cached_resolve_chat_template_kwargs = lru_cache(maxsize=128)(
    _resolve_chat_template_kwargs
)


1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
def resolve_chat_template_kwargs(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    chat_template: str,
    chat_template_kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Keep only the kwargs that the tokenizer or the template can use.

    A key is kept when ``supports_kw`` reports it as a keyword of
    ``tokenizer.apply_chat_template`` (with ``allow_var_kwargs=False``), or
    when it is an undeclared variable of ``chat_template``.
    ``chat_template`` itself is always dropped.
    """
    apply_fn = tokenizer.apply_chat_template
    fn_kw = {
        key
        for key in chat_template_kwargs
        if supports_kw(apply_fn, key, allow_var_kwargs=False)
    }
    template_vars = _cached_resolve_chat_template_kwargs(chat_template)

    # We exclude chat_template from kwargs here, because
    # chat template has been already resolved at this stage
    accept_vars = (fn_kw | template_vars) - {"chat_template"}

    filtered: dict[str, Any] = {}
    for key, value in chat_template_kwargs.items():
        if key in accept_vars:
            filtered[key] = value
    return filtered


1612
1613
def apply_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    conversation: list[ConversationMessage],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    *,
    model_config: ModelConfig,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> str:
    """Apply a Hugging Face chat template to ``conversation``.

    The template is first resolved through ``resolve_hf_chat_template``;
    ``kwargs`` are then filtered with ``resolve_chat_template_kwargs`` before
    being forwarded to ``tokenizer.apply_chat_template``.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template``.
        conversation: Messages to render.
        chat_template: Explicit template, or ``None`` to let
            ``resolve_hf_chat_template`` choose one.
        tools: Optional tool definitions passed to both template resolution
            and ``apply_chat_template``.
        model_config: Passed through to ``resolve_hf_chat_template``.
        tokenize: Forwarded to ``apply_chat_template``; defaults to ``False``.
        **kwargs: Extra chat-template keyword arguments; only those accepted
            by ``resolve_chat_template_kwargs`` are forwarded.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns.

    Raises:
        ValueError: If no chat template could be resolved, or if any
            exception is raised while filtering kwargs or applying the
            template (the original exception is chained as the cause).
    """
    hf_chat_template = resolve_hf_chat_template(
        tokenizer,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
    )

    if hf_chat_template is None:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        )

    try:
        resolved_kwargs = resolve_chat_template_kwargs(
            tokenizer=tokenizer,
            chat_template=hf_chat_template,
            chat_template_kwargs=kwargs,
        )
        return tokenizer.apply_chat_template(
            conversation=conversation,  # type: ignore[arg-type]
            tools=tools,  # type: ignore[arg-type]
            chat_template=hf_chat_template,
            tokenize=tokenize,
            **resolved_kwargs,
        )

    # External library exceptions can sometimes occur despite the framework's
    # internal exception management capabilities.
    except Exception as e:
        # Log and report any library-related exceptions for further
        # investigation.
        logger.exception(
            "An error occurred in `transformers` while applying chat template"
        )
        raise ValueError(str(e)) from e
1659

1660

1661
1662
def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: list[ChatCompletionMessageParam],
    chat_template: Optional[str],
    tools: Optional[list[dict[str, Any]]],
    **kwargs: Any,
) -> list[int]:
    """Apply the Mistral tokenizer's chat template to ``messages``.

    Args:
        tokenizer: Mistral tokenizer providing ``apply_chat_template``.
        messages: Chat messages to encode.
        chat_template: Passed only to ``resolve_mistral_chat_template``; it
            is not forwarded to the tokenizer.
        tools: Optional tool definitions forwarded to the tokenizer.
        **kwargs: Forwarded to both ``resolve_mistral_chat_template`` and
            ``tokenizer.apply_chat_template``.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns.

    Raises:
        ValueError: If the tokenizer raises ``AssertionError``,
            ``MistralCommonException`` or any other exception; the original
            exception is chained as the cause.
    """
    from mistral_common.exceptions import MistralCommonException

    # The return value of resolve_mistral_chat_template is always None,
    # and we won't use it.
    resolve_mistral_chat_template(
        chat_template=chat_template,
        **kwargs,
    )

    try:
        return tokenizer.apply_chat_template(
            messages=messages,
            tools=tools,
            **kwargs,
        )
    # mistral-common uses assert statements to stop processing of input
    # if input does not comply with the expected format.
    # We convert those assertion errors to ValueErrors so they can be
    # properly caught in the preprocessing_input step
    except (AssertionError, MistralCommonException) as e:
        raise ValueError(str(e)) from e

    # External library exceptions can sometimes occur despite the framework's
    # internal exception management capabilities.
    except Exception as e:
        # Log and report any library-related exceptions for further
        # investigation.
        logger.exception(
            "An error occurred in `mistral_common` while applying chat template"
        )
        raise ValueError(str(e)) from e
1699

1700

1701
1702
1703
def get_history_tool_calls_cnt(
        conversation: list[ConversationMessage]) -> int:
    """Count the tool calls made by assistant messages in ``conversation``.

    Only messages whose ``"role"`` is ``"assistant"`` are considered; a
    missing or ``None`` ``"tool_calls"`` entry contributes zero.

    Args:
        conversation: Messages to scan.

    Returns:
        Total number of tool-call entries across all assistant messages.
    """
    total = 0
    for msg in conversation:
        if msg["role"] != "assistant":
            continue
        tool_calls = msg.get("tool_calls")
        if tool_calls is not None:
            # Materialize before counting, as the original did; presumably
            # `tool_calls` may be an iterable without `len()` - TODO confirm.
            total += len(list(tool_calls))
    return total


1710
1711
1712
def make_tool_call_id(id_type: str = "random", func_name=None,
                      idx=None) -> str:
    """Build an identifier string for a tool call.

    Args:
        id_type: ``"kimi_k2"`` selects the ``functions.<name>:<idx>`` format;
            any other value yields a random identifier.
        func_name: Function name embedded in the ``kimi_k2`` format.
        idx: Index embedded in the ``kimi_k2`` format.

    Returns:
        ``f"functions.{func_name}:{idx}"`` for ``"kimi_k2"``, otherwise
        ``"chatcmpl-tool-"`` followed by the result of ``random_uuid()``.
    """
    if id_type == "kimi_k2":
        return f"functions.{func_name}:{idx}"
    # by default return random
    return f"chatcmpl-tool-{random_uuid()}"