# chat_utils.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
import warnings
7
from abc import ABC, abstractmethod
8
from collections import Counter, defaultdict
9
from collections.abc import Awaitable, Callable, Iterable
10
from functools import cached_property, lru_cache, partial
11
from itertools import accumulate
12
from pathlib import Path
13
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
14

15
from openai.types.chat import (
16
17
18
19
20
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionContentPartTextParam,
21
    ChatCompletionFunctionToolParam,
22
23
24
25
26
27
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
28
from openai.types.chat import (
29
30
31
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
32
from openai.types.responses import ResponseInputImageParam
33
from openai_harmony import Message as OpenAIHarmonyMessage
34
35
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
36

37
# pydantic needs the TypedDict from typing_extensions
38
from typing_extensions import Required, TypedDict
39

40
from vllm import envs
41
from vllm.config import ModelConfig
42
from vllm.logger import init_logger
43
from vllm.model_executor.models import SupportsMultiModal
44
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
45
46
47
48
49
50
from vllm.multimodal.inputs import (
    MultiModalBatchedField,
    MultiModalFlatField,
    MultiModalSharedField,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
51
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
52
from vllm.utils import random_uuid
53
54
55
56
57
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    import torch
58
    import transformers
59
else:
60
    transformers = LazyLoader("transformers", globals(), "transformers")
61
    torch = LazyLoader("torch", globals(), "torch")
62
63
64

# Module-level logger, created through vLLM's `init_logger` with this
# module's dotted name.
logger = init_logger(__name__)

65

66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def __getattr__(name: str):
    """Module-level attribute hook that serves one deprecated alias.

    Accessing ``resolve_hf_chat_template`` on this module emits a
    ``DeprecationWarning`` and returns
    ``vllm.renderers.hf.resolve_chat_template``. Every other name raises
    ``AttributeError``.
    """
    # Guard clause: only a single legacy name is forwarded.
    if name != "resolve_hf_chat_template":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Imported on access rather than at module import time.
    from vllm.renderers.hf import resolve_chat_template

    deprecation_message = (
        "`vllm.entrypoints.chat_utils.resolve_hf_chat_template` has been moved to "
        "`vllm.renderers.hf.resolve_chat_template`. "
        "The old name will be removed in v0.16."
    )
    # stacklevel=2 attributes the warning to the code that accessed the alias.
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    return resolve_chat_template


83
84
85
86
87
88
89
90
class ChatTemplateResolutionError(ValueError):
    """Signals that a chat template could not be resolved.

    Derives from ``ValueError`` so that handlers which already catch
    ``ValueError`` keep catching this error as well.
    """


91
92
93
94
95
96
# Generic placeholder token for each raw modality name. The content parsers
# in this module use these tokens as keys of their placeholder storage.
MODALITY_PLACEHOLDERS_MAP = dict(
    image="<##IMAGE##>",
    audio="<##AUDIO##>",
    video="<##VIDEO##>",
)

97

98
99
100
101
102
103
104
105
106
107
108
109
110
111
class AudioURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Content part that references audio through an ``AudioURL`` payload."""

    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """Discriminator of the content part; always ``"audio_url"``."""


112
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
113
    image_embeds: str | dict[str, str] | None
114
115
116
117
118
119
120
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
121
    uuid: str | None
122
123
124
125
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
126
127


128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
class ChatCompletionContentPartAudioEmbedsParam(TypedDict, total=False):
    audio_embeds: str | dict[str, str] | None
    """
    The audio embeddings. It can be either:
    - A single base64 string representing a serialized torch tensor.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["audio_embeds"]]
    """The type of the content part."""
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """


144
145
146
147
148
149
150
151
152
153
154
155
156
157
class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Content part that references a video through a ``VideoURL`` payload."""

    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """Discriminator of the content part; always ``"video_url"``."""


158
159
160
161
class PILImage(BaseModel):
    """
    A PIL.Image.Image object.
    """

    # The wrapped image instance.
    image_pil: Image.Image
    # `Image.Image` is not a pydantic type, so arbitrary types are enabled
    # for this model.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    image_pil: PILImage | None
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
182
183


184
185
186
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.
187

188
189
190
191
192
    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """
193

194
195
    image_url: str | None
    uuid: str | None
196
197
198
199
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
200
201
202
203


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.
204

205
206
207
208
209
    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """
210

211
    audio_url: str | None
212
213


214
215
216
217
218
219
220
221
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """
222

223
224
    video_url: str | None
    uuid: str | None
225
226
227
228
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
229
230


# (blame-view annotation: committed by Julien Denize)
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


252
253
254
255
256
257
258
259
260
# Union of every content-part shape this module accepts: the OpenAI SDK
# content parts, the audio/video/embedding/simple/thinking parts declared in
# this module, and plain strings. Member order is kept exactly as declared.
ChatCompletionContentPartParam: TypeAlias = (
    OpenAIChatCompletionContentPartParam
    | ChatCompletionContentPartAudioParam
    | ChatCompletionContentPartInputAudioParam
    | ChatCompletionContentPartVideoParam
    | ChatCompletionContentPartRefusalParam
    | CustomChatCompletionContentPILImageParam
    | CustomChatCompletionContentSimpleImageParam
    | ChatCompletionContentPartImageEmbedsParam
    | ChatCompletionContentPartAudioEmbedsParam
    | CustomChatCompletionContentSimpleAudioParam
    | CustomChatCompletionContentSimpleVideoParam
    | str
    | CustomThinkCompletionContentParam
)
267
268
269
270


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""

    role: Required[str]
    """The role of the message's author."""

    content: str | list[ChatCompletionContentPartParam]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: str | None
    """Tool call that this message is responding to."""

    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
    """The tool calls generated by the model, such as function calls."""

    reasoning: str | None
    """The reasoning content for interleaved thinking."""

    tools: list[ChatCompletionFunctionToolParam] | None
    """The tools for developer role."""

297

298
299
300
301
302
# Union of the message shapes accepted here: the OpenAI SDK message params,
# the custom-role message param declared in this module, and the `Message`
# type imported from `openai_harmony`.
ChatCompletionMessageParam: TypeAlias = (
    OpenAIChatCompletionMessageParam
    | CustomChatCompletionMessageParam
    | OpenAIHarmonyMessage
)
303
304


305
# TODO: Make fields ReadOnly once mypy supports it
306
307
308
309
class ConversationMessage(TypedDict, total=False):
    """A single conversation message; only ``role`` is required."""

    role: Required[str]
    """The role of the message's author."""

    content: str | None | list[dict[str, str]]
    """The contents of the message"""

    tool_call_id: str | None
    """Tool call that this message is responding to."""

    name: str | None
    """The name of the function to call"""

    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
    """The tool calls generated by the model, such as function calls."""

    reasoning: str | None
    """The reasoning content for interleaved thinking."""

    reasoning_content: str | None
    """Deprecated: The reasoning content for interleaved thinking."""

    tools: list[ChatCompletionFunctionToolParam] | None
    """The tools for developer role."""

331

332
333
334
# Passed in by user
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

# After resolving "auto"
ChatTemplateContentFormat = Literal["string", "openai"]

# Modality names accepted by the item trackers. The "*_embeds" variants are
# mapped back to their base modality by stripping the "_embeds" suffix.
ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]

# Type of the items stored by a multi-modal item tracker.
_T = TypeVar("_T")


343
344
345
# Backward compatibility for single item input
class _BatchedSingleItemField(MultiModalSharedField):
    """Marker subclass of ``MultiModalSharedField``.

    ``_detect_field`` returns it for the deprecated batched single-item
    input, and ``_merge_embeds`` checks for it by type so that such fields
    are not overridden. It adds no behavior of its own.
    """
346
347


348
349
350
351
352
353
def _detect_field(
    tensors: list[torch.Tensor],
    mm_processor: BaseMultiModalProcessor,
):
    """Pick the multi-modal field used to merge one key's embedding tensors.

    Args:
        tensors: The tensors supplied for a single key, one per item. The
            first element is read unconditionally, so the list must be
            non-empty.
        mm_processor: Processor whose model config provides the input
            embedding size.

    Returns:
        - ``_BatchedSingleItemField(batch_size=1)`` for the deprecated case
          of one 3-D tensor whose first dimension is 1 and whose last
          dimension equals the embedding size (a warning is logged);
        - ``MultiModalBatchedField()`` when every tensor has the same shape;
        - otherwise a ``MultiModalFlatField`` with one slice per tensor,
          sized by ``len(tensor)``.
    """
    first_item = tensors[0]
    hidden_size = mm_processor.info.ctx.model_config.get_inputs_embeds_size()

    if (
        len(tensors) == 1
        and first_item.ndim == 3
        and first_item.shape[0] == 1
        and first_item.shape[-1] == hidden_size
    ):
        logger.warning(
            "Batched multi-modal embedding inputs are deprecated for Chat API. "
            "Please pass a separate content part for each multi-modal item."
        )
        return _BatchedSingleItemField(batch_size=1)

    first_shape = first_item.shape
    if all(t.shape == first_shape for t in tensors):
        return MultiModalBatchedField()

    # Running offsets [0, n0, n0 + n1, ...]; consecutive pairs delimit the
    # span each tensor occupies along the first dimension.
    slice_idxs = [0, *accumulate(len(tensor) for tensor in tensors)]
    slices = [
        (slice(start, stop),) for start, stop in zip(slice_idxs, slice_idxs[1:])
    ]
    return MultiModalFlatField(slices=slices)


def _merge_embeds(
    data_items: list[dict[str, "torch.Tensor"]],
    mm_processor: BaseMultiModalProcessor,
):
    """Merge per-item embedding dictionaries into one dictionary per key.

    Args:
        data_items: One dictionary per multi-modal item; all of them must
            have the same set of keys.
        mm_processor: Processor used to detect fields and, when possible,
            to obtain the processor-declared field configuration.

    Returns:
        ``{}`` for an empty input; otherwise a dictionary mapping each key
        to the result of reducing that key's per-item values.

    Raises:
        ValueError: If the dictionaries do not all share the same keys.
    """
    if not data_items:
        return {}

    first_keys = set(data_items[0].keys())
    if any(set(item.keys()) != first_keys for item in data_items[1:]):
        raise ValueError(
            "All dictionaries in the list of embeddings must have the same keys."
        )

    def _column(key: str) -> list:
        # Fresh list of the per-item values stored under `key`.
        return [item[key] for item in data_items]

    # First pass: merge using fields auto-detected from the tensor shapes.
    fields = {key: _detect_field(_column(key), mm_processor) for key in first_keys}
    data_merged = {
        key: field._reduce_data(_column(key), pin_memory=False)
        for key, field in fields.items()
    }

    # Second pass (best effort): where the processor declares a different
    # field for a key, redo the merge with it. Keys detected as the
    # deprecated batched single item are left untouched.
    try:
        # TODO: Support per-request mm_processor_kwargs
        parsed_configs = mm_processor._get_mm_fields_config(
            transformers.BatchFeature(data_merged),
            {},
        )
        parsed_fields = {key: parsed_configs[key].field for key in first_keys}
        keys_to_update = [
            key
            for key in first_keys
            if (
                fields[key] != parsed_fields[key]
                and not isinstance(fields[key], _BatchedSingleItemField)
            )
        ]

        for key in keys_to_update:
            data_merged[key] = parsed_fields[key]._reduce_data(
                _column(key), pin_memory=False
            )
    except Exception:
        # Deliberate fallback: log the failure and keep whatever has been
        # merged so far.
        logger.exception(
            "Error when parsing merged embeddings. "
            "Falling back to auto-detected fields."
        )

    return data_merged
428

429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449

def _get_embeds_data(
    modality: str,
    data_items: list[Any],
    mm_processor: BaseMultiModalProcessor,
):
    if len(data_items) == 0:
        return data_items

    if all(item is None for item in data_items):
        return data_items

    if is_list_of(data_items, torch.Tensor):
        embeds_key = f"{modality}_embeds"
        dict_items = [{embeds_key: item} for item in data_items]
        return _merge_embeds(dict_items, mm_processor)[embeds_key]

    if is_list_of(data_items, dict):
        return _merge_embeds(data_items, mm_processor)

    raise NotImplementedError(type(data_items))
450
451


452
class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        self._model_config = model_config

        # Items recorded so far, keyed by the modality string given to `add`.
        self._items_by_modality = defaultdict[str, list[_T]](list)

    @property
    def model_config(self) -> ModelConfig:
        """The model configuration this tracker was created with."""
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        """Model class for the current config, resolved once and cached."""
        from vllm.model_executor.model_loader import get_model_cls

        model_cls = get_model_cls(self.model_config)
        return cast(type[SupportsMultiModal], model_cls)

    @property
    def allowed_local_media_path(self):
        """`allowed_local_media_path` of the model config."""
        return self._model_config.allowed_local_media_path

    @property
    def allowed_media_domains(self):
        """`allowed_media_domains` of the model config."""
        return self._model_config.allowed_media_domains

    @property
    def mm_registry(self):
        """The multi-modal registry used to create the processor."""
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        """Multi-modal processor for the current config, created once."""
        return self.mm_registry.create_processor(self.model_config)

    def add(self, modality: ModalityStr, item: _T) -> str | None:
        """
        Add a multi-modal item to the current prompt and returns the
        placeholder string to use, if any.

        The item count is validated through the processor, under the
        modality name with any ``_embeds`` suffix removed, before the item
        is stored; an item that fails validation is not recorded.
        """
        input_modality = modality.replace("_embeds", "")
        # `.get` avoids the defaultdict creating an empty entry here: if
        # validation raises, the tracker must not look as if it held items
        # for this modality.
        num_items = len(self._items_by_modality.get(modality, ())) + 1

        self.mm_processor.validate_num_items(input_modality, num_items)

        self._items_by_modality[modality].append(item)

        return self.model_cls.get_placeholder_str(modality, num_items)

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        """Create the content parser that feeds items into this tracker."""
        raise NotImplementedError


515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
def _resolve_items(
    items_by_modality: dict[str, list[tuple[object, str | None]]],
    mm_processor: BaseMultiModalProcessor,
) -> tuple[MultiModalDataDict, MultiModalUUIDDict]:
    if "image" in items_by_modality and "image_embeds" in items_by_modality:
        raise ValueError("Mixing raw image and embedding inputs is not allowed")
    if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
        raise ValueError("Mixing raw audio and embedding inputs is not allowed")

    mm_data = {}
    mm_uuids = {}
    if "image_embeds" in items_by_modality:
        mm_data["image"] = _get_embeds_data(
            "image",
            [data for data, uuid in items_by_modality["image_embeds"]],
            mm_processor,
        )
        mm_uuids["image"] = [uuid for data, uuid in items_by_modality["image_embeds"]]
    if "image" in items_by_modality:
        mm_data["image"] = [data for data, uuid in items_by_modality["image"]]
        mm_uuids["image"] = [uuid for data, uuid in items_by_modality["image"]]
    if "audio_embeds" in items_by_modality:
        mm_data["audio"] = _get_embeds_data(
            "audio",
            [data for data, uuid in items_by_modality["audio_embeds"]],
            mm_processor,
        )
        mm_uuids["audio"] = [uuid for data, uuid in items_by_modality["audio_embeds"]]
    if "audio" in items_by_modality:
        mm_data["audio"] = [data for data, uuid in items_by_modality["audio"]]
        mm_uuids["audio"] = [uuid for data, uuid in items_by_modality["audio"]]
    if "video" in items_by_modality:
        mm_data["video"] = [data for data, uuid in items_by_modality["video"]]
        mm_uuids["video"] = [uuid for data, uuid in items_by_modality["video"]]

    return mm_data, mm_uuids


class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]]):
    """Tracker whose items are already-loaded ``(data, uuid)`` tuples."""

    def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """Return ``(mm_data, mm_uuids)``, or ``(None, None)`` if nothing
        has been tracked."""
        if not self._items_by_modality:
            return None, None

        # A plain-dict copy is handed over rather than the defaultdict.
        return _resolve_items(dict(self._items_by_modality), self.mm_processor)

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Create the synchronous content parser bound to this tracker."""
        return MultiModalContentParser(self)


566
567
568
569
570
571
class AsyncMultiModalItemTracker(
    BaseMultiModalItemTracker[Awaitable[tuple[object, str | None]]]
):
    """Tracker whose items are awaitables resolving to ``(data, uuid)``."""

    async def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """Await every tracked item and return ``(mm_data, mm_uuids)``.

        Returns ``(None, None)`` if nothing has been tracked.
        """
        if not self._items_by_modality:
            return None, None

        # Gather all modalities at once. Awaiting them one modality after
        # another would delay starting the later modalities' awaitables
        # until the earlier modalities had finished.
        modalities = list(self._items_by_modality)
        resolved_lists = await asyncio.gather(
            *(asyncio.gather(*self._items_by_modality[m]) for m in modalities)
        )
        resolved_items_by_modality = dict(zip(modalities, resolved_lists))

        return _resolve_items(resolved_items_by_modality, self.mm_processor)

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Create the asynchronous content parser bound to this tracker."""
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    """Interface for parsers that turn multi-modal content parts into
    tracked items and record the placeholder strings they produce."""

    def __init__(self) -> None:
        super().__init__()

        # stores model placeholders list with corresponding
        # general MM placeholder:
        # {
        #   "<##IMAGE##>": ["<image>", "<image>", "<image>"],
        #   "<##AUDIO##>": ["<audio>", "<audio>"]
        # }
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(self, modality: ModalityStr, placeholder: str | None):
        """Record ``placeholder`` under the generic token of ``modality``.

        Empty or ``None`` placeholders are ignored. ``modality`` must be a
        key of ``MODALITY_PLACEHOLDERS_MAP``.
        """
        mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
        if placeholder:
            self._placeholder_storage[mod_placeholder].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return a shallow plain-dict copy of the recorded placeholders."""
        return dict(self._placeholder_storage)

    @abstractmethod
    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Handle an image given by URL."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Handle image embeddings given as a string, a dict, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Handle an image given as a PIL image object."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Handle audio given by URL."""
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Handle audio given as an ``InputAudio`` payload."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Handle audio embeddings given as a string, a dict, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Handle a video given by URL."""
        raise NotImplementedError

646
647
648
649
650
651

class MultiModalContentParser(BaseMultiModalContentParser):
    """Synchronous parser: media is fetched through the media connector
    inside each ``parse_*`` call and stored in the tracker as
    ``(data, uuid)`` tuples."""

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        multimodal_config = self._tracker.model_config.multimodal_config
        # `multimodal_config` is read defensively; a missing attribute
        # yields `None`.
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        """Model configuration of the underlying tracker."""
        return self._tracker.model_config

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Fetch the image (``None`` for an empty URL) and track it."""
        image = self._connector.fetch_image(image_url) if image_url else None

        placeholder = self._tracker.add("image", (image, uuid))
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode image embeddings and track them as ``image_embeds``.

        Raises:
            ValueError: If multi-modal embedding inputs are not enabled.
            TypeError: If ``image_embeds`` is not a str, a dict or ``None``.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        # One exclusive chain: an unexpected type previously matched none of
        # the independent `if`s and left `placeholder` unbound.
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            placeholder = self._tracker.add("image_embeds", (embeds, uuid))
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            placeholder = self._tracker.add("image_embeds", (embedding, uuid))
        elif image_embeds is None:
            placeholder = self._tracker.add("image_embeds", (None, uuid))
        else:
            raise TypeError(
                "`image_embeds` must be a str, a dict or None, "
                f"got {type(image_embeds).__name__}"
            )

        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode audio embeddings and track them as ``audio_embeds``.

        Any value that is neither a dict nor a str is tracked as ``None``.

        Raises:
            ValueError: If multi-modal embedding inputs are not enabled.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        if isinstance(audio_embeds, dict):
            embeds = {
                k: self._connector.fetch_audio_embedding(v)
                for k, v in audio_embeds.items()
            }
            placeholder = self._tracker.add("audio_embeds", (embeds, uuid))
        elif isinstance(audio_embeds, str):
            embedding = self._connector.fetch_audio_embedding(audio_embeds)
            placeholder = self._tracker.add("audio_embeds", (embedding, uuid))
        else:
            placeholder = self._tracker.add("audio_embeds", (None, uuid))

        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Track an already-loaded PIL image as it is."""
        placeholder = self._tracker.add("image", (image_pil, uuid))
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Fetch the audio (``None`` for an empty URL) and track it."""
        audio = self._connector.fetch_audio(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", (audio, uuid))
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Turn an ``InputAudio`` payload into a data URL and parse it."""
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"
            else:
                # If a UUID is provided, audio data may be empty.
                audio_url = None
        else:
            audio_url = None

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Fetch the video (``None`` for an empty URL) and track it."""
        video = self._connector.fetch_video(video_url=video_url) if video_url else None

        placeholder = self._tracker.add("video", (video, uuid))
        self._add_placeholder("video", placeholder)
757

758
759
760
761
762
763

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Asynchronous parser: URL media is tracked as coroutines that fetch it
    when awaited; embeddings and PIL images are tracked as futures that
    already hold their result."""

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        multimodal_config = self._tracker.model_config.multimodal_config
        # `multimodal_config` is read defensively; a missing attribute
        # yields `None`.
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        """Model configuration of the underlying tracker."""
        return self._tracker.model_config

    async def _image_with_uuid_async(self, image_url: str | None, uuid: str | None):
        """Fetch the image (``None`` for an empty URL); return it with uuid."""
        image = (
            await self._connector.fetch_image_async(image_url) if image_url else None
        )
        return image, uuid

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Track a coroutine that fetches the image when awaited."""
        coro = self._image_with_uuid_async(image_url, uuid)

        placeholder = self._tracker.add("image", coro)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode image embeddings now and track them as a resolved future.

        Raises:
            ValueError: If multi-modal embedding inputs are not enabled.
            TypeError: If ``image_embeds`` is not a str, a dict or ``None``.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        future = asyncio.Future[
            tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
        ]()

        # One exclusive chain: an unexpected type previously matched none of
        # the independent `if`s, leaving the future without a result so that
        # awaiting it would never complete.
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            future.set_result((embeds, uuid))
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            future.set_result((embedding, uuid))
        elif image_embeds is None:
            future.set_result((None, uuid))
        else:
            raise TypeError(
                "`image_embeds` must be a str, a dict or None, "
                f"got {type(image_embeds).__name__}"
            )

        placeholder = self._tracker.add("image_embeds", future)
        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode audio embeddings now and track them as a resolved future.

        Raises:
            ValueError: If multi-modal embedding inputs are not enabled.
            TypeError: If ``audio_embeds`` is not a str, a dict or ``None``.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        future = asyncio.Future[
            tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
        ]()

        # Same exclusive chain as for image embeddings, so the future always
        # ends up resolved or the call fails loudly.
        if isinstance(audio_embeds, dict):
            embeds = {
                k: self._connector.fetch_audio_embedding(v)
                for k, v in audio_embeds.items()
            }
            future.set_result((embeds, uuid))
        elif isinstance(audio_embeds, str):
            embedding = self._connector.fetch_audio_embedding(audio_embeds)
            future.set_result((embedding, uuid))
        elif audio_embeds is None:
            future.set_result((None, uuid))
        else:
            raise TypeError(
                "`audio_embeds` must be a str, a dict or None, "
                f"got {type(audio_embeds).__name__}"
            )

        placeholder = self._tracker.add("audio_embeds", future)
        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self,
        image_pil: Image.Image | None,
        uuid: str | None = None,
    ) -> None:
        """Track an already-loaded PIL image as a resolved future."""
        future = asyncio.Future[tuple[Image.Image | None, str | None]]()
        if image_pil:
            future.set_result((image_pil, uuid))
        else:
            future.set_result((None, uuid))

        placeholder = self._tracker.add("image", future)
        self._add_placeholder("image", placeholder)

    async def _audio_with_uuid_async(self, audio_url: str | None, uuid: str | None):
        """Fetch the audio (``None`` for an empty URL); return it with uuid."""
        audio = (
            await self._connector.fetch_audio_async(audio_url) if audio_url else None
        )
        return audio, uuid

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Track a coroutine that fetches the audio when awaited."""
        coro = self._audio_with_uuid_async(audio_url, uuid)

        placeholder = self._tracker.add("audio", coro)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Turn an ``InputAudio`` payload into a data URL and parse it."""
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"
            else:
                # If a UUID is provided, audio data may be empty.
                audio_url = None
        else:
            audio_url = None

        return self.parse_audio(audio_url, uuid)

    async def _video_with_uuid_async(self, video_url: str | None, uuid: str | None):
        """Fetch the video (``None`` for an empty URL); return it with uuid."""
        video = (
            await self._connector.fetch_video_async(video_url) if video_url else None
        )
        return video, uuid

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Track a coroutine that fetches the video when awaited."""
        coro = self._video_with_uuid_async(video_url, uuid)

        placeholder = self._tracker.add("video", coro)
        self._add_placeholder("video", placeholder)
906

907

908
def validate_chat_template(chat_template: Path | str | None):
909
910
911
912
913
    """Raises if the provided chat template appears invalid."""
    if chat_template is None:
        return

    elif isinstance(chat_template, Path) and not chat_template.exists():
914
        raise FileNotFoundError("the supplied chat template path doesn't exist")
915
916
917

    elif isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
918
919
920
921
        if (
            not any(c in chat_template for c in JINJA_CHARS)
            and not Path(chat_template).exists()
        ):
922
923
924
            # Try to find the template in the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
925
            )
926

927
928
929
930
931
932
933
934
            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            if not builtin_template_path.exists():
                raise ValueError(
                    f"The supplied chat template string ({chat_template}) "
                    f"appears path-like, but doesn't exist! "
                    f"Tried: {chat_template} and {builtin_template_path}"
                )

935
    else:
936
        raise TypeError(f"{type(chat_template)} is not a valid chat template type")
937
938


939
def _load_chat_template(
940
    chat_template: Path | str | None,
941
942
    *,
    is_literal: bool = False,
943
) -> str | None:
944
945
    if chat_template is None:
        return None
946
947
948

    if is_literal:
        if isinstance(chat_template, Path):
949
950
951
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )
952

953
        return chat_template
954

955
    try:
956
        with open(chat_template) as f:
957
            return f.read()
958
    except OSError as e:
959
960
961
        if isinstance(chat_template, Path):
            raise

962
963
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
964
965
966
            # Try to load from the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
967
            )
968
969
970
971
972
973
974
975
976
977
978
979
980

            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            try:
                with open(builtin_template_path) as f:
                    return f.read()
            except OSError:
                msg = (
                    f"The supplied chat template ({chat_template}) "
                    f"looks like a file path, but it failed to be opened. "
                    f"Tried: {chat_template} and {builtin_template_path}. "
                    f"Reason: {e}"
                )
                raise ValueError(msg) from e
981

982
983
        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
984
985
986
987
988
989
990
        return _load_chat_template(chat_template, is_literal=True)


# Memoized wrapper around `_load_chat_template`: repeated calls with the same
# arguments reuse the previously returned text instead of resolving it again.
_cached_load_chat_template = lru_cache(_load_chat_template)


def load_chat_template(
    chat_template: Path | str | None,
    *,
    is_literal: bool = False,
) -> str | None:
    """Return the chat template text for ``chat_template``, with caching.

    Args:
        chat_template: A template path, a template name, the template text
            itself, or ``None``.
        is_literal: When true, ``chat_template`` is taken to be the template
            text rather than something to load.

    Returns:
        Whatever ``_load_chat_template`` returns for the same arguments; the
        result is cached per argument combination.
    """
    return _cached_load_chat_template(chat_template, is_literal=is_literal)
996
997


998
999
1000
def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
1001
1002
1003
1004
1005
1006
1007
    for idx, elem in enumerate(texts):
        if elem in placeholder_storage:
            texts[idx] = placeholder_storage[elem].pop(0)

    return "\n".join(texts)


1008
# TODO: Let the user specify how to insert multimodal tokens into the prompt
# (similar to the chat template)
1010
1011
1012
1013
1014
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
) -> str:
1015
    """Combine multimodal prompts for a multimodal language model."""
1016

1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
    # flatten storage to make it looks like
    # {
    #   "<|image|>": 2,
    #   "<|audio|>": 1
    # }
    placeholder_counts = Counter(
        [v for elem in placeholder_storage.values() for v in elem]
    )

    if interleave_strings:
        text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
    else:
        text_prompt = "\n".join(texts)

    # Pass interleaved text further in case the user used image placeholders
    # himself, but forgot to disable the 'interleave_strings' flag

1034
    # Look through the text prompt to check for missing placeholders
1035
    missing_placeholders: list[str] = []
1036
1037
1038
1039
1040
    for placeholder in placeholder_counts:
        # For any existing placeholder in the text prompt, we leave it as is
        placeholder_counts[placeholder] -= text_prompt.count(placeholder)

        if placeholder_counts[placeholder] < 0:
1041
1042
1043
1044
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
1045
1046
                "when manually placing image placeholders.",
                interleave_strings,
1047
1048
            )
            logger.debug("Input prompt: %s", text_prompt)
1049
1050
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
1051
1052
                "actual multimodal data items."
            )
1053

1054
        missing_placeholders.extend([placeholder] * placeholder_counts[placeholder])
1055

1056
1057
    # NOTE: Default behaviour: we always add missing placeholders
    # at the front of the prompt, if interleave_strings=False
1058
    return "\n".join(missing_placeholders + [text_prompt])
1059
1060


1061
1062
# These parts need no further Pydantic validation, so their "parsers" are
# plain `typing.cast` calls.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioEmbedsParser = partial(cast, ChatCompletionContentPartAudioEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)

# Parts carrying URL objects are validated through Pydantic.
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python
_ContentPart: TypeAlias = str | dict[str, str] | InputAudio | PILImage

# Maps each supported part type to a function that extracts that part's
# content; every extractor yields None when the relevant field is absent.
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    "text": lambda part: _TextParser(part).get("text"),
    "thinking": lambda part: _ThinkParser(part).get("thinking"),
    "input_text": lambda part: _TextParser(part).get("text"),
    "output_text": lambda part: _TextParser(part).get("text"),
    "input_image": lambda part: _ResponsesInputImageParser(part).get("image_url"),
    "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url"),
    "image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds"),
    "audio_embeds": lambda part: _AudioEmbedsParser(part).get("audio_embeds"),
    "image_pil": lambda part: _PILImageParser(part).get("image_pil"),
    "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url"),
    "input_audio": lambda part: _InputAudioParser(part).get("input_audio"),
    "refusal": lambda part: _RefusalParser(part).get("refusal"),
    "video_url": lambda part: _VideoParser(part).get("video_url", {}).get("url"),
}


def _parse_chat_message_content_mm_part(
1099
1100
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
1101
    """
1102
    Parses a given multi-modal content part based on its type.
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL).

    Raises:
        ValueError: If the 'type' field is missing and no direct URL is found.
    """
    assert isinstance(
1116
1117
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
1118
    part_type = part.get("type", None)
1119
    uuid = part.get("uuid", None)
1120

1121
    if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None:  # noqa: E501
1122
1123
1124
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
1125
1126
        # We only support 'auto', which is the default
        if part_type == "image_url" and part.get("detail", "auto") != "auto":
1127
            logger.warning(
1128
                "'image_url.detail' is currently not supported and will be ignored."
1129
            )
1130
1131
1132
1133

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
1134
    # 'type' is required field by pydantic
1135
1136
    if part_type is None or uuid is not None:
        if "image_url" in part:
1137
            image_params = cast(CustomChatCompletionContentSimpleImageParam, part)
1138
1139
1140
1141
1142
1143
1144
1145
            image_url = image_params.get("image_url", None)
            if isinstance(image_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                image_url = image_url.get("url", None)
            return "image_url", image_url
        if "image_pil" in part:
            # "image_pil" could be None if UUID is provided.
1146
            image_params = cast(  # type: ignore
1147
1148
1149
1150
1151
1152
                CustomChatCompletionContentPILImageParam, part
            )
            image_pil = image_params.get("image_pil", None)
            return "image_pil", image_pil
        if "image_embeds" in part:
            # "image_embeds" could be None if UUID is provided.
1153
            image_params = cast(  # type: ignore
1154
1155
1156
1157
                ChatCompletionContentPartImageEmbedsParam, part
            )
            image_embeds = image_params.get("image_embeds", None)
            return "image_embeds", image_embeds
1158
1159
1160
1161
1162
1163
1164
        if "audio_embeds" in part:
            # "audio_embeds" could be None if UUID is provided.
            audio_params = cast(  # type: ignore[assignment]
                ChatCompletionContentPartAudioEmbedsParam, part
            )
            audio_embeds = audio_params.get("audio_embeds", None)
            return "audio_embeds", audio_embeds
1165
        if "audio_url" in part:
1166
1167
1168
            audio_params = cast(  # type: ignore[assignment]
                CustomChatCompletionContentSimpleAudioParam, part
            )
1169
1170
1171
1172
1173
1174
            audio_url = audio_params.get("audio_url", None)
            if isinstance(audio_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                audio_url = audio_url.get("url", None)
            return "audio_url", audio_url
1175
        if part.get("input_audio") is not None:
1176
            input_audio_params = cast(dict[str, str], part)
1177
            return "input_audio", input_audio_params
1178
        if "video_url" in part:
1179
            video_params = cast(CustomChatCompletionContentSimpleVideoParam, part)
1180
1181
1182
1183
1184
1185
            video_url = video_params.get("video_url", None)
            if isinstance(video_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                video_url = video_url.get("url", None)
            return "video_url", video_url
1186
1187
1188
1189
1190
1191
1192
1193
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    return part_type, "unknown part_type content"


1194
# Part types that are skipped (with a warning) when their content is None.
PART_TYPES_TO_SKIP_NONE_CONTENT = ("text", "refusal")
1198

1199

1200
1201
1202
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse all content parts of one message into a conversation message.

    Args:
        role: Role assigned to the resulting message.
        parts: The content parts to parse, in order.
        mm_tracker: Tracker from which the per-message parser is created.
        wrap_dicts: When true, the parsed parts are kept as a list of dicts;
            otherwise they are combined into a single text prompt.
        interleave_strings: Forwarded to the part parser and to the prompt
            builder.

    Returns:
        A single-element list holding the resulting message.
    """
    mm_parser = mm_tracker.create_parser()

    parsed_parts = (
        _parse_chat_message_content_part(
            part,
            mm_parser,
            wrap_dicts=wrap_dicts,
            interleave_strings=interleave_strings,
        )
        for part in parts
    )
    # Parts that produced nothing (None or an empty value) are left out.
    content: list[_ContentPart] = [parsed for parsed in parsed_parts if parsed]

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role, content=content)]  # type: ignore

    texts = cast(list[str], content)
    mm_placeholder_storage = mm_parser.mm_placeholder_storage()
    text_prompt = (
        _get_full_multimodal_text_prompt(
            mm_placeholder_storage, texts, interleave_strings
        )
        if mm_placeholder_storage
        else "\n".join(texts)
    )
    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
1238
1239
1240
1241
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
1242
    interleave_strings: bool,
1243
) -> _ContentPart | None:
1244
1245
1246
1247
1248
1249
1250
1251
    """Parses a single part of a conversation. If wrap_dicts is True,
    structured dictionary pieces for texts and images will be
    wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
    {"type": "image"}, respectively. Otherwise multimodal data will be
    handled by mm_parser, and texts will be returned as strings to be joined
    with multimodal placeholders.
    """
    if isinstance(part, str):  # Handle plain text parts
1252
        return part
1253
1254
    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)
1255
    # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
1256
    # content is None, log a warning and skip
1257
    if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None:
1258
        logger.warning(
1259
            "Skipping multimodal part '%s' (type: '%s') "
1260
1261
1262
1263
            "with empty / unparsable content.",
            part,
            part_type,
        )
1264
1265
        return None

1266
    if part_type in ("text", "input_text", "output_text", "refusal", "thinking"):
1267
1268
        str_content = cast(str, content)
        if wrap_dicts:
1269
            return {"type": "text", "text": str_content}
1270
1271
        else:
            return str_content
1272

1273
1274
1275
1276
1277
1278
    # For media items, if a user has provided one, use it. Otherwise, insert
    # a placeholder empty uuid.
    uuid = part.get("uuid", None)
    if uuid is not None:
        uuid = str(uuid)

1279
    modality = None
1280
    if part_type == "image_pil":
1281
        image_content = cast(Image.Image, content) if content is not None else None
1282
        mm_parser.parse_image_pil(image_content, uuid)
1283
        modality = "image"
1284
    elif part_type in ("image_url", "input_image"):
1285
        str_content = cast(str, content)
1286
        mm_parser.parse_image(str_content, uuid)
1287
1288
        modality = "image"
    elif part_type == "image_embeds":
1289
        content = cast(str | dict[str, str], content) if content is not None else None
1290
        mm_parser.parse_image_embeds(content, uuid)
1291
        modality = "image"
1292
1293
1294
1295
    elif part_type == "audio_embeds":
        content = cast(str | dict[str, str], content) if content is not None else None
        mm_parser.parse_audio_embeds(content, uuid)
        modality = "audio"
1296
    elif part_type == "audio_url":
1297
        str_content = cast(str, content)
1298
        mm_parser.parse_audio(str_content, uuid)
1299
1300
        modality = "audio"
    elif part_type == "input_audio":
1301
        dict_content = cast(InputAudio, content)
1302
        mm_parser.parse_input_audio(dict_content, uuid)
1303
1304
        modality = "audio"
    elif part_type == "video_url":
1305
        str_content = cast(str, content)
1306
        mm_parser.parse_video(str_content, uuid)
1307
1308
1309
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")
1310

1311
1312
1313
    return (
        {"type": modality}
        if wrap_dicts
1314
        else (MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None)
1315
    )
1316
1317


1318
1319
1320
1321
1322
# No need to validate using Pydantic again: these "parsers" are plain
# `typing.cast` calls, so they only re-label the message for the type
# checker and perform no runtime validation or conversion.
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


1323
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: ChatTemplateContentFormat,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse one chat message into conversation message(s).

    The message content is normalized to a list of parts and parsed. Each
    resulting message then receives the role-specific fields of the input:
    ``tool_calls`` and reasoning for assistant messages, ``tool_call_id`` for
    tool messages, ``name`` when it is a string, and ``tools`` for developer
    messages.

    Args:
        message: The incoming chat message.
        mm_tracker: Tracker that collects the multimodal items.
        content_format: Parts are kept as dicts when this is "openai".
        interleave_strings: Forwarded to the parts parser.

    Returns:
        The parsed conversation messages.
    """
    role = message["role"]
    content = message.get("content")
    reasoning = message.get("reasoning") or message.get("reasoning_content")

    # Normalize absent or plain-string content into a list of parts.
    if content is None:
        content = []
    elif isinstance(content, str):
        content = [ChatCompletionContentPartTextParam(type="text", text=content)]

    result = _parse_chat_message_content_parts(
        role,
        content,  # type: ignore
        mm_tracker,
        wrap_dicts=(content_format == "openai"),
        interleave_strings=interleave_strings,
    )

    for result_msg in result:
        if role == "assistant":
            assistant_msg = _AssistantParser(message)

            # The 'tool_calls' is not None check ensures compatibility.
            # It's needed only if downstream code doesn't strictly
            # follow the OpenAI spec.
            tool_calls = assistant_msg.get("tool_calls")
            if tool_calls is not None:
                result_msg["tool_calls"] = list(tool_calls)

            # Include reasoning if present for interleaved thinking.
            if reasoning is not None:
                reasoning_text = cast(str, reasoning)
                result_msg["reasoning"] = reasoning_text
                # keep compatibility
                result_msg["reasoning_content"] = reasoning_text
        elif role == "tool":
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                result_msg["tool_call_id"] = tool_msg["tool_call_id"]

        name = message.get("name")
        if isinstance(name, str):
            result_msg["name"] = name

        if role == "developer":
            result_msg["tools"] = message.get("tools", None)

    return result

1372

1373
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
1374
1375
1376
1377
1378
1379
    # per the Transformers docs & maintainers, tool call arguments in
    # assistant-role messages with tool_calls need to be dicts not JSON str -
    # this is how tool-use chat templates will expect them moving forwards
    # so, for messages that have tool_calls, parse the string (which we get
    # from openAI format) to dict
    for message in messages:
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
        if message["role"] == "assistant" and "tool_calls" in message:
            tool_calls = message.get("tool_calls")
            if not isinstance(tool_calls, list):
                continue

            if len(tool_calls) == 0:
                # Drop empty tool_calls to keep templates on the normal assistant path.
                message.pop("tool_calls", None)
                continue

            for item in tool_calls:
1391
1392
                # if arguments is None or empty string, set to {}
                if content := item["function"].get("arguments"):
1393
1394
                    if not isinstance(content, (dict, list)):
                        item["function"]["arguments"] = json.loads(content)
1395
1396
                else:
                    item["function"]["arguments"] = {}
1397
1398


1399
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Parse chat messages into a conversation plus its multimodal items.

    Args:
        messages: The chat messages to parse, in order.
        model_config: Model configuration; its multimodal settings decide
            whether multimodal placeholders are interleaved with the text.
        content_format: Content format expected by the chat template.

    Returns:
        A ``(conversation, mm_data, mm_uuids)`` tuple.
    """
    mm_tracker = MultiModalItemTracker(model_config)

    # Interleaving applies only to the "string" format and only when the
    # multimodal config enables it; the value is the same for every message.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
            )
        )

    _postprocess_messages(conversation)

    mm_data, mm_uuids = mm_tracker.resolve_items()
    return conversation, mm_data, mm_uuids
1430
1431


1432
async def parse_chat_messages_async(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Parse chat messages, awaiting the resolution of multimodal items.

    Args:
        messages: The chat messages to parse, in order.
        model_config: Model configuration; its multimodal settings decide
            whether multimodal placeholders are interleaved with the text.
        content_format: Content format expected by the chat template.

    Returns:
        A ``(conversation, mm_data, mm_uuids)`` tuple.
    """
    mm_tracker = AsyncMultiModalItemTracker(model_config)

    # Interleaving applies only to the "string" format and only when the
    # multimodal config enables it; the value is the same for every message.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
            )
        )

    _postprocess_messages(conversation)

    mm_data, mm_uuids = await mm_tracker.resolve_items()
    return conversation, mm_data, mm_uuids
1463

1464

1465
1466
1467
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
    """Count the tool calls made by assistant messages in ``conversation``.

    Messages of other roles, and assistant messages whose ``tool_calls`` is
    missing or ``None``, contribute nothing.

    Returns:
        The total number of tool calls.
    """
    assistant_tool_calls = (
        msg.get("tool_calls") for msg in conversation if msg["role"] == "assistant"
    )
    return sum(
        len(list(tool_calls))
        for tool_calls in assistant_tool_calls
        if tool_calls is not None
    )


1474
1475
1476
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build an identifier for a tool call.

    Args:
        id_type: ``"kimi_k2"`` selects the ``functions.<name>:<idx>`` form;
            any other value yields a random identifier.
        func_name: Function name used in the ``"kimi_k2"`` form.
        idx: Index used in the ``"kimi_k2"`` form.

    Returns:
        The tool call identifier string.
    """
    if id_type != "kimi_k2":
        # by default return random
        return f"chatcmpl-tool-{random_uuid()}"
    return f"functions.{func_name}:{idx}"