chat_utils.py 49.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
import warnings
7
from abc import ABC, abstractmethod
8
from collections import Counter, defaultdict
9
from collections.abc import Awaitable, Callable, Iterable
10
from functools import cached_property, lru_cache, partial
11
from pathlib import Path
12
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
13

14
from openai.types.chat import (
15
16
17
18
19
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionContentPartTextParam,
20
    ChatCompletionFunctionToolParam,
21
22
23
24
25
26
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
27
from openai.types.chat import (
28
29
30
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
31
from openai.types.responses import ResponseInputImageParam
32
from openai_harmony import Message as OpenAIHarmonyMessage
33
34
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
35

36
# pydantic needs the TypedDict from typing_extensions
37
from typing_extensions import Required, TypedDict
38

39
from vllm import envs
40
from vllm.config import ModelConfig
41
from vllm.logger import init_logger
42
from vllm.model_executor.models import SupportsMultiModal
43
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
44
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
45
from vllm.utils import random_uuid
46
47
48
49
50
51
52
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    import torch
else:
    torch = LazyLoader("torch", globals(), "torch")
53
54
55

logger = init_logger(__name__)

56

57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def __getattr__(name: str):
    """Module-level attribute hook (PEP 562) that serves deprecated aliases.

    Only ``resolve_hf_chat_template`` is handled: accessing it emits a
    ``DeprecationWarning`` and returns ``vllm.renderers.hf.resolve_chat_template``.
    Any other missing name raises ``AttributeError`` as usual.
    """
    if name != "resolve_hf_chat_template":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from vllm.renderers.hf import resolve_chat_template

    warnings.warn(
        "`vllm.entrypoints.chat_utils.resolve_hf_chat_template` has been moved to "
        "`vllm.renderers.hf.resolve_chat_template`. "
        "The old name will be removed in v0.16.",
        DeprecationWarning,
        stacklevel=2,
    )
    return resolve_chat_template


74
75
76
77
78
79
80
81
class ChatTemplateResolutionError(ValueError):
    """Error signalling that a chat template could not be resolved.

    It derives from ``ValueError`` so that handlers which already catch
    ``ValueError`` keep working unchanged.
    """

    pass


82
83
84
85
86
87
# Generic, model-agnostic placeholder string for each raw modality. Content
# parsers group the model-specific placeholders they collect under these keys.
MODALITY_PLACEHOLDERS_MAP = dict(
    image="<##IMAGE##>",
    audio="<##AUDIO##>",
    video="<##VIDEO##>",
)

88

89
90
91
92
93
94
95
96
97
98
99
100
101
102
class AudioURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


103
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
104
    image_embeds: str | dict[str, str] | None
105
106
107
108
109
110
111
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
112
    uuid: str | None
113
114
115
116
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
117
118


119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
class ChatCompletionContentPartAudioEmbedsParam(TypedDict, total=False):
    audio_embeds: str | dict[str, str] | None
    """
    The audio embeddings. It can be either:
    - A single base64 string representing a serialized torch tensor.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["audio_embeds"]]
    """The type of the content part."""
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """


135
136
137
138
139
140
141
142
143
144
145
146
147
148
class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""


149
150
151
152
class PILImage(BaseModel):
    """
    Pydantic model that wraps a single ``PIL.Image.Image`` object.
    """

    image_pil: Image.Image
    # PIL images are not pydantic-native, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """Simplified content part that carries a PIL image directly.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    image_pil: PILImage | None
    uuid: str | None
    """
    Optional caller-supplied UUID for the media item. The caller is responsible
    for generating it properly and keeping it unique across different media.
    """
173
174


175
176
177
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.
178

179
180
181
182
183
    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """
184

185
186
    image_url: str | None
    uuid: str | None
187
188
189
190
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
191
192
193
194


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.
195

196
197
198
199
200
    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """
201

202
    audio_url: str | None
203
204


205
206
207
208
209
210
211
212
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """
213

214
215
    video_url: str | None
    uuid: str | None
216
217
218
219
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
220
221


Julien Denize's avatar
Julien Denize committed
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


243
244
245
246
247
248
249
250
251
# Every content-part shape accepted inside a chat message: OpenAI's own parts,
# the audio/video/embedding/PIL/simple-URL/thinking variants declared in this
# module, and bare strings.
# NOTE(review): member order may influence pydantic union matching; keep it.
ChatCompletionContentPartParam: TypeAlias = (
    OpenAIChatCompletionContentPartParam
    | ChatCompletionContentPartAudioParam
    | ChatCompletionContentPartInputAudioParam
    | ChatCompletionContentPartVideoParam
    | ChatCompletionContentPartRefusalParam
    | CustomChatCompletionContentPILImageParam
    | CustomChatCompletionContentSimpleImageParam
    | ChatCompletionContentPartImageEmbedsParam
    | ChatCompletionContentPartAudioEmbedsParam
    | CustomChatCompletionContentSimpleAudioParam
    | CustomChatCompletionContentSimpleVideoParam
    | str
    | CustomThinkCompletionContentParam
)
258
259
260
261


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Chat message whose ``role`` may be any string (custom roles allowed)."""

    role: Required[str]
    """Role of the message author."""

    content: str | list[ChatCompletionContentPartParam]
    """Body of the message: plain text or a list of content parts."""

    name: str
    """Optional participant name.

    Lets the model tell apart different participants that share a role.
    """

    tool_call_id: str | None
    """ID of the tool call this message answers."""

    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
    """Tool calls (such as function calls) produced by the model."""

    reasoning: str | None
    """Reasoning text used for interleaved thinking."""

    tools: list[ChatCompletionFunctionToolParam] | None
    """Tools made available through the developer role."""

288

289
290
291
292
293
# Any accepted chat message: an OpenAI message param, the custom-role variant,
# or an openai_harmony ``Message``.
ChatCompletionMessageParam: TypeAlias = (
    OpenAIChatCompletionMessageParam
    | CustomChatCompletionMessageParam
    | OpenAIHarmonyMessage
)
294
295


296
# TODO: Make fields ReadOnly once mypy supports it
297
298
299
300
class ConversationMessage(TypedDict, total=False):
    """A single message of a conversation; only ``role`` is required."""

    role: Required[str]
    """Role of the message author."""

    content: str | None | list[dict[str, str]]
    """Body of the message."""

    tool_call_id: str | None
    """ID of the tool call this message answers."""

    name: str | None
    """Name of the function to call."""

    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
    """Tool calls (such as function calls) produced by the model."""

    reasoning: str | None
    """Reasoning text used for interleaved thinking."""

    reasoning_content: str | None
    """Deprecated alias of ``reasoning`` (interleaved-thinking text)."""

    tools: list[ChatCompletionFunctionToolParam] | None
    """Tools made available through the developer role."""

322

323
324
325
# Passed in by user
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

326
327
# After resolving "auto"
ChatTemplateContentFormat = Literal["string", "openai"]
328

329

330
ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
331
332
333
_T = TypeVar("_T")


334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
def _extract_embeds(tensors: list[torch.Tensor]):
    if len(tensors) == 0:
        return tensors

    if len(tensors) == 1:
        tensors[0]._is_single_item = True  # type: ignore
        return tensors[0]  # To keep backwards compatibility for single item input

    first_shape = tensors[0].shape
    if all(t.shape == first_shape for t in tensors):
        return torch.stack(tensors)

    return tensors


def _get_embeds_data(items_by_modality: dict[str, list[Any]], modality: str):
    """Combine the ``f"{modality}_embeds"`` items into a single MM data value.

    Args:
        items_by_modality: Tracked items keyed by modality name; must contain
            the ``f"{modality}_embeds"`` key.
        modality: Base modality name, e.g. ``"image"`` or ``"audio"``.

    Returns:
        - the (empty) list unchanged when there are no items;
        - ``_extract_embeds(items)`` when every item is a tensor;
        - a dict mapping each key to ``_extract_embeds`` of that key's values
          when every item is a dict;
        - the list unchanged otherwise.

    Raises:
        ValueError: If the dict items do not all share the same keys.
    """
    embeds = items_by_modality[f"{modality}_embeds"]

    if not embeds:
        return embeds
    if is_list_of(embeds, torch.Tensor):
        return _extract_embeds(embeds)
    if is_list_of(embeds, dict):
        # `embeds` is known to be non-empty here, so embeds[0] is safe.
        first_keys = set(embeds[0].keys())
        if any(set(item.keys()) != first_keys for item in embeds[1:]):
            raise ValueError(
                "All dictionaries in the list of embeddings must have the same keys."
            )

        return {k: _extract_embeds([item[k] for item in embeds]) for k in first_keys}

    return embeds


372
class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Records the multi-modal items (and optional user-supplied UUIDs) found in
    one request. Every addition is validated against the per-prompt item limit
    through the model's multi-modal processor before it is stored.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        self._model_config = model_config

        # Parallel lists keyed by modality name ("image", "image_embeds", ...):
        # position i in both lists describes the i-th item of that modality.
        self._items_by_modality: defaultdict[str, list[_T | None]] = defaultdict(list)
        self._uuids_by_modality: defaultdict[str, list[str | None]] = defaultdict(list)

    @property
    def model_config(self) -> ModelConfig:
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        # The model loader is imported on first use, not at module import time.
        from vllm.model_executor.model_loader import get_model_cls

        return cast(type[SupportsMultiModal], get_model_cls(self.model_config))

    @property
    def allowed_local_media_path(self):
        return self._model_config.allowed_local_media_path

    @property
    def allowed_media_domains(self):
        return self._model_config.allowed_media_domains

    @property
    def mm_registry(self):
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        return self.mm_registry.create_processor(self.model_config)

    def add(
        self,
        modality: ModalityStr,
        item: _T | None,
        uuid: str | None = None,
    ) -> str | None:
        """
        Register one multi-modal item for the current prompt and return the
        placeholder string the model uses for it, if any.

        ``uuid`` optionally identifies the media item.
        """
        # "*_embeds" items count against the limit of their base modality.
        base_modality = modality.replace("_embeds", "")
        position = len(self._items_by_modality[modality]) + 1

        self.mm_processor.validate_num_items(base_modality, position)

        self._items_by_modality[modality].append(item)
        self._uuids_by_modality[modality].append(uuid)

        return self.model_cls.get_placeholder_str(modality, position)

    def all_mm_uuids(self) -> MultiModalUUIDDict | None:
        """Return the collected UUIDs keyed by base modality, or None if empty.

        Raises:
            ValueError: If raw and embedding inputs are mixed for image or audio.
        """
        if not self._items_by_modality:
            return None

        uuids = dict(self._uuids_by_modality)
        if "image" in uuids and "image_embeds" in uuids:
            raise ValueError("Mixing raw image and embedding inputs is not allowed")
        if "audio" in uuids and "audio_embeds" in uuids:
            raise ValueError("Mixing raw audio and embedding inputs is not allowed")

        mm_uuids = {}
        for base in ("image", "audio"):
            embeds_key = f"{base}_embeds"
            if embeds_key in uuids:
                mm_uuids[base] = uuids[embeds_key]
            if base in uuids:
                mm_uuids[base] = uuids[base]
        if "video" in uuids:
            mm_uuids["video"] = uuids["video"]

        return mm_uuids

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        raise NotImplementedError


466
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
    """Tracker whose items are already-materialized media objects."""

    def all_mm_data(self) -> MultiModalDataDict | None:
        """Group the collected items by base modality, or return None if empty.

        Embedding items (``image_embeds`` / ``audio_embeds``) are merged through
        ``_get_embeds_data`` and stored under their base modality key.

        Raises:
            ValueError: If raw and embedding inputs are mixed for image or audio.
        """
        if not self._items_by_modality:
            return None

        collected = dict(self._items_by_modality)
        if "image" in collected and "image_embeds" in collected:
            raise ValueError("Mixing raw image and embedding inputs is not allowed")
        if "audio" in collected and "audio_embeds" in collected:
            raise ValueError("Mixing raw audio and embedding inputs is not allowed")

        mm_inputs = {}
        for base in ("image", "audio"):
            if f"{base}_embeds" in collected:
                mm_inputs[base] = _get_embeds_data(collected, base)
            if base in collected:
                mm_inputs[base] = collected[base]
        if "video" in collected:
            mm_inputs["video"] = collected["video"]

        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)


495
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
    """Tracker whose items are awaitables, resolved only in ``all_mm_data``."""

    async def all_mm_data(self) -> MultiModalDataDict | None:
        """Await every tracked item and group the results by base modality.

        Returns None when nothing was tracked. Embedding items are merged
        through ``_get_embeds_data`` and stored under their base modality key.

        Raises:
            ValueError: If raw and embedding inputs are mixed for image or audio.
        """
        if not self._items_by_modality:
            return None

        # Falsy (None) entries get a no-op awaitable so gather() still yields
        # one result per position.
        pending = {
            modality: [item or asyncio.sleep(0) for item in items]
            for modality, items in self._items_by_modality.items()
        }
        resolved: dict[str, list[object | None]] = {}
        for modality, awaitables in pending.items():
            resolved[modality] = await asyncio.gather(*awaitables)

        if "image" in resolved and "image_embeds" in resolved:
            raise ValueError("Mixing raw image and embedding inputs is not allowed")
        if "audio" in resolved and "audio_embeds" in resolved:
            raise ValueError("Mixing raw audio and embedding inputs is not allowed")

        mm_inputs = {}
        for base in ("image", "audio"):
            if f"{base}_embeds" in resolved:
                mm_inputs[base] = _get_embeds_data(resolved, base)
            if base in resolved:
                mm_inputs[base] = resolved[base]
        if "video" in resolved:
            mm_inputs["video"] = resolved["video"]

        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    """Common interface for parsers that turn multi-modal content parts into
    tracked items.

    Each model-specific placeholder returned for an item is recorded under the
    generic placeholder of its modality (from ``MODALITY_PLACEHOLDERS_MAP``),
    for example::

        {
          "<##IMAGE##>": ["<image>", "<image>", "<image>"],
          "<##AUDIO##>": ["<audio>", "<audio>"]
        }
    """

    def __init__(self) -> None:
        super().__init__()

        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(self, modality: ModalityStr, placeholder: str | None):
        # Resolve the generic key first, so an unknown modality raises KeyError
        # even when there is no placeholder to record.
        generic_key = MODALITY_PLACEHOLDERS_MAP[modality]
        if not placeholder:
            return
        self._placeholder_storage[generic_key].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return a plain-dict copy of the recorded placeholders."""
        return dict(self._placeholder_storage)

    @abstractmethod
    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Register an image given by URL (or only a UUID)."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Register precomputed image embeddings (or only a UUID)."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Register an in-memory PIL image (or only a UUID)."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Register an audio clip given by URL (or only a UUID)."""
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Register an OpenAI-style ``input_audio`` payload (or only a UUID)."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Register precomputed audio embeddings (or only a UUID)."""
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Register a video given by URL (or only a UUID)."""
        raise NotImplementedError

591
592
593
594
595
596

class MultiModalContentParser(BaseMultiModalContentParser):
    """Synchronous content parser.

    Media is fetched eagerly through the configured media connector, and the
    results are registered with a :class:`MultiModalItemTracker`.
    """

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

        multimodal_config = self._tracker.model_config.multimodal_config
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        return self._tracker.model_config

    def _check_mm_embeds_enabled(self, field: str) -> None:
        """Raise ValueError unless embedding inputs are enabled for the model."""
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(f"You must set `--enable-mm-embeds` to input `{field}`")

    @staticmethod
    def _fetch_embeds(
        embeds: str | dict[str, str] | None,
        fetch: Callable[[str], Any],
        field: str,
    ) -> Any:
        """Decode ``embeds`` with ``fetch`` while keeping its str/dict/None shape.

        ``None`` means the caller referenced the media by UUID only.

        Raises:
            ValueError: If ``embeds`` is neither a string, a dict nor ``None``.
                Previously an unexpected type caused ``UnboundLocalError``
                (image) or was silently treated as ``None`` (audio).
        """
        if embeds is None:
            return None
        if isinstance(embeds, dict):
            return {k: fetch(v) for k, v in embeds.items()}
        if isinstance(embeds, str):
            return fetch(embeds)
        raise ValueError(
            f"`{field}` must be a base64-encoded string, a dict of base64-encoded "
            f"strings, or None; got {type(embeds).__name__}"
        )

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        image = self._connector.fetch_image(image_url) if image_url else None

        placeholder = self._tracker.add("image", image, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        self._check_mm_embeds_enabled("image_embeds")

        embeds = self._fetch_embeds(
            image_embeds, self._connector.fetch_image_embedding, "image_embeds"
        )
        placeholder = self._tracker.add("image_embeds", embeds, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        self._check_mm_embeds_enabled("audio_embeds")

        embeds = self._fetch_embeds(
            audio_embeds, self._connector.fetch_audio_embedding, "audio_embeds"
        )
        placeholder = self._tracker.add("audio_embeds", embeds, uuid)
        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        placeholder = self._tracker.add("image", image_pil, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        audio = self._connector.fetch_audio(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", audio, uuid)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        # Convert the OpenAI `input_audio` payload into a base64 data URL.
        audio_url = None
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            # If a UUID is provided, audio data may be empty.
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        video = self._connector.fetch_video(video_url=video_url) if video_url else None

        placeholder = self._tracker.add("video", video, uuid)
        self._add_placeholder("video", placeholder)
702

703
704
705
706
707
708

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Asynchronous content parser.

    Instead of fetching media eagerly, it registers awaitables with an
    :class:`AsyncMultiModalItemTracker`: either coroutines from the media
    connector or futures that are already resolved.
    """

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        multimodal_config = self._tracker.model_config.multimodal_config
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        return self._tracker.model_config

    def _check_mm_embeds_enabled(self, field: str) -> None:
        """Raise ValueError unless embedding inputs are enabled for the model."""
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(f"You must set `--enable-mm-embeds` to input `{field}`")

    @staticmethod
    def _resolved_future(value: object) -> asyncio.Future[object]:
        """Wrap an already-known value in a completed future for the tracker."""
        future: asyncio.Future[object] = asyncio.Future()
        future.set_result(value)
        return future

    @staticmethod
    def _fetch_embeds(
        embeds: str | dict[str, str] | None,
        fetch: Callable[[str], Any],
        field: str,
    ) -> Any:
        """Decode ``embeds`` with ``fetch`` while keeping its str/dict/None shape.

        ``None`` means the caller referenced the media by UUID only.

        Raises:
            ValueError: If ``embeds`` is neither a string, a dict nor ``None``.
                Previously an unexpected type left the tracked future unresolved,
                so awaiting it hung forever.
        """
        if embeds is None:
            return None
        if isinstance(embeds, dict):
            return {k: fetch(v) for k, v in embeds.items()}
        if isinstance(embeds, str):
            return fetch(embeds)
        raise ValueError(
            f"`{field}` must be a base64-encoded string, a dict of base64-encoded "
            f"strings, or None; got {type(embeds).__name__}"
        )

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        image_coro = self._connector.fetch_image_async(image_url) if image_url else None

        placeholder = self._tracker.add("image", image_coro, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        self._check_mm_embeds_enabled("image_embeds")

        embeds = self._fetch_embeds(
            image_embeds, self._connector.fetch_image_embedding, "image_embeds"
        )
        placeholder = self._tracker.add(
            "image_embeds", self._resolved_future(embeds), uuid
        )
        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        self._check_mm_embeds_enabled("audio_embeds")

        # Debug level: this runs once per audio part on the request path.
        logger.debug(
            "Parsing audio_embeds: type=%s, uuid=%s",
            type(audio_embeds).__name__,
            uuid,
        )

        embeds = self._fetch_embeds(
            audio_embeds, self._connector.fetch_audio_embedding, "audio_embeds"
        )
        placeholder = self._tracker.add(
            "audio_embeds", self._resolved_future(embeds), uuid
        )
        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        future = self._resolved_future(image_pil if image_pil else None)

        placeholder = self._tracker.add("image", future, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        audio_coro = self._connector.fetch_audio_async(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", audio_coro, uuid)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        # Convert the OpenAI `input_audio` payload into a base64 data URL.
        audio_url = None
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            # If a UUID is provided, audio data may be empty.
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        video = (
            self._connector.fetch_video_async(video_url=video_url)
            if video_url
            else None
        )

        placeholder = self._tracker.add("video", video, uuid)
        self._add_placeholder("video", placeholder)
862

863

864
def validate_chat_template(chat_template: Path | str | None):
865
866
867
868
869
    """Raises if the provided chat template appears invalid."""
    if chat_template is None:
        return

    elif isinstance(chat_template, Path) and not chat_template.exists():
870
        raise FileNotFoundError("the supplied chat template path doesn't exist")
871
872
873

    elif isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
874
875
876
877
        if (
            not any(c in chat_template for c in JINJA_CHARS)
            and not Path(chat_template).exists()
        ):
878
879
880
            # Try to find the template in the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
881
            )
882

883
884
885
886
887
888
889
890
            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            if not builtin_template_path.exists():
                raise ValueError(
                    f"The supplied chat template string ({chat_template}) "
                    f"appears path-like, but doesn't exist! "
                    f"Tried: {chat_template} and {builtin_template_path}"
                )

891
    else:
892
        raise TypeError(f"{type(chat_template)} is not a valid chat template type")
893
894


895
def _load_chat_template(
896
    chat_template: Path | str | None,
897
898
    *,
    is_literal: bool = False,
899
) -> str | None:
900
901
    if chat_template is None:
        return None
902
903
904

    if is_literal:
        if isinstance(chat_template, Path):
905
906
907
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )
908

909
        return chat_template
910

911
    try:
912
        with open(chat_template) as f:
913
            return f.read()
914
    except OSError as e:
915
916
917
        if isinstance(chat_template, Path):
            raise

918
919
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
920
921
922
            # Try to load from the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
923
            )
924
925
926
927
928
929
930
931
932
933
934
935
936

            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            try:
                with open(builtin_template_path) as f:
                    return f.read()
            except OSError:
                msg = (
                    f"The supplied chat template ({chat_template}) "
                    f"looks like a file path, but it failed to be opened. "
                    f"Tried: {chat_template} and {builtin_template_path}. "
                    f"Reason: {e}"
                )
                raise ValueError(msg) from e
937

938
939
        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
940
941
942
943
944
945
946
        return _load_chat_template(chat_template, is_literal=True)


# Memoized loader; `lru_cache` applied directly to a function keeps up to
# 128 distinct argument combinations.
_cached_load_chat_template = lru_cache(maxsize=128)(_load_chat_template)


def load_chat_template(
    chat_template: Path | str | None,
    *,
    is_literal: bool = False,
) -> str | None:
    """Return the chat template text, using a memoized ``_load_chat_template``.

    Because results are cached per ``(chat_template, is_literal)``, edits to
    a template file after its first load are not picked up.
    """
    loader = _cached_load_chat_template
    return loader(chat_template, is_literal=is_literal)


def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
957
958
959
960
961
962
963
    for idx, elem in enumerate(texts):
        if elem in placeholder_storage:
            texts[idx] = placeholder_storage[elem].pop(0)

    return "\n".join(texts)


964
# TODO: Let user specify how to insert multimodal tokens into prompt
965
# (similar to chat template)
966
967
968
969
970
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
) -> str:
971
    """Combine multimodal prompts for a multimodal language model."""
972

973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
    # flatten storage to make it looks like
    # {
    #   "<|image|>": 2,
    #   "<|audio|>": 1
    # }
    placeholder_counts = Counter(
        [v for elem in placeholder_storage.values() for v in elem]
    )

    if interleave_strings:
        text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
    else:
        text_prompt = "\n".join(texts)

    # Pass interleaved text further in case the user used image placeholders
    # himself, but forgot to disable the 'interleave_strings' flag

990
    # Look through the text prompt to check for missing placeholders
991
    missing_placeholders: list[str] = []
992
993
994
995
996
    for placeholder in placeholder_counts:
        # For any existing placeholder in the text prompt, we leave it as is
        placeholder_counts[placeholder] -= text_prompt.count(placeholder)

        if placeholder_counts[placeholder] < 0:
997
998
999
1000
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
1001
1002
                "when manually placing image placeholders.",
                interleave_strings,
1003
1004
            )
            logger.debug("Input prompt: %s", text_prompt)
1005
1006
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
1007
1008
                "actual multimodal data items."
            )
1009

1010
        missing_placeholders.extend([placeholder] * placeholder_counts[placeholder])
1011

1012
1013
    # NOTE: Default behaviour: we always add missing placeholders
    # at the front of the prompt, if interleave_strings=False
1014
    return "\n".join(missing_placeholders + [text_prompt])
1015
1016


1017
1018
# These parts need no further Pydantic validation: `cast` only re-labels the
# value for the type checker and returns it unchanged at runtime.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioEmbedsParser = partial(cast, ChatCompletionContentPartAudioEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)

# Parts carrying URL objects are validated through Pydantic TypeAdapters.
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python
_ContentPart: TypeAlias = str | dict[str, str] | InputAudio | PILImage

# Maps a content part's 'type' to a function extracting its payload; a
# missing payload key yields None.
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    "text": lambda part: _TextParser(part).get("text"),
    "thinking": lambda part: _ThinkParser(part).get("thinking"),
    "input_text": lambda part: _TextParser(part).get("text"),
    "output_text": lambda part: _TextParser(part).get("text"),
    "input_image": lambda part: _ResponsesInputImageParser(part).get("image_url"),
    # URL-bearing parts nest the URL as {"<key>": {"url": ...}}.
    "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url"),
    "image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds"),
    "audio_embeds": lambda part: _AudioEmbedsParser(part).get("audio_embeds"),
    "image_pil": lambda part: _PILImageParser(part).get("image_pil"),
    "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url"),
    "input_audio": lambda part: _InputAudioParser(part).get("input_audio"),
    "refusal": lambda part: _RefusalParser(part).get("refusal"),
    "video_url": lambda part: _VideoParser(part).get("video_url", {}).get("url"),
}


def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
    """
    Parses a given multi-modal content part based on its type.

    Parts with a known string 'type' and no 'uuid' are dispatched through
    ``MM_PARSER_MAP``. Parts without a 'type', or with a 'uuid', are matched
    on the payload key they carry instead; on that path the payload itself
    may be ``None`` (the uuid then identifies the item).

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL). For 'input_audio'
          this is the nested ``{"data": ..., "format": ...}`` payload. A
          string 'type' not in ``MM_PARSER_MAP`` yields the placeholder
          content "unknown part_type content".

    Raises:
        ValueError: If the 'type' field is missing and no known payload key is
            found, or if 'type' is present but not a string.
    """
    assert isinstance(
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
    part_type = part.get("type", None)
    uuid = part.get("uuid", None)

    if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None:
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
        # We only support 'auto', which is the default. In the OpenAI schema
        # 'detail' is nested inside the 'image_url' object; a top-level
        # 'detail' is also checked for backward compatibility.
        if part_type == "image_url":
            image_url_obj = part.get("image_url")
            nested_detail = (
                image_url_obj.get("detail", "auto")
                if isinstance(image_url_obj, dict)
                else "auto"
            )
            if nested_detail != "auto" or part.get("detail", "auto") != "auto":
                logger.warning(
                    "'image_url.detail' is currently not supported "
                    "and will be ignored."
                )

        return part_type, content

    # Reached when 'type' is missing or a 'uuid' is given: identify the part
    # by the payload key it carries.
    if part_type is None or uuid is not None:
        if "image_url" in part:
            image_params = cast(CustomChatCompletionContentSimpleImageParam, part)
            image_url = image_params.get("image_url", None)
            if isinstance(image_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                image_url = image_url.get("url", None)
            return "image_url", image_url
        if "image_pil" in part:
            # "image_pil" could be None if UUID is provided.
            image_params = cast(  # type: ignore
                CustomChatCompletionContentPILImageParam, part
            )
            image_pil = image_params.get("image_pil", None)
            return "image_pil", image_pil
        if "image_embeds" in part:
            # "image_embeds" could be None if UUID is provided.
            image_params = cast(  # type: ignore
                ChatCompletionContentPartImageEmbedsParam, part
            )
            image_embeds = image_params.get("image_embeds", None)
            return "image_embeds", image_embeds
        if "audio_embeds" in part:
            # "audio_embeds" could be None if UUID is provided.
            audio_params = cast(  # type: ignore[assignment]
                ChatCompletionContentPartAudioEmbedsParam, part
            )
            audio_embeds = audio_params.get("audio_embeds", None)
            return "audio_embeds", audio_embeds
        if "audio_url" in part:
            audio_params = cast(  # type: ignore[assignment]
                CustomChatCompletionContentSimpleAudioParam, part
            )
            audio_url = audio_params.get("audio_url", None)
            if isinstance(audio_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                audio_url = audio_url.get("url", None)
            return "audio_url", audio_url
        if part.get("input_audio") is not None:
            # Return the nested {"data", "format"} payload, the same value
            # MM_PARSER_MAP["input_audio"] yields for typed parts: consumers
            # read 'data' and 'format' directly from it. Returning the whole
            # part here made the audio data unreachable.
            input_audio_params = cast(ChatCompletionContentPartInputAudioParam, part)
            return "input_audio", input_audio_params["input_audio"]
        if "video_url" in part:
            video_params = cast(CustomChatCompletionContentSimpleVideoParam, part)
            video_url = video_params.get("video_url", None)
            if isinstance(video_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                video_url = video_url.get("url", None)
            return "video_url", video_url
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    return part_type, "unknown part_type content"


# Part types that are skipped entirely, rather than parsed, when their
# extracted content is None.
PART_TYPES_TO_SKIP_NONE_CONTENT = ("text", "refusal")


def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Turn the content parts of one message into a single conversation message.

    With ``wrap_dicts`` the message content is the list of parsed part
    dicts. Otherwise the text pieces are joined with newlines and, when the
    per-message parser collected multimodal placeholders, combined with them
    via ``_get_full_multimodal_text_prompt``.
    """
    mm_parser = mm_tracker.create_parser()

    parsed = (
        _parse_chat_message_content_part(
            part,
            mm_parser,
            wrap_dicts=wrap_dicts,
            interleave_strings=interleave_strings,
        )
        for part in parts
    )
    # Falsy results (skipped parts and empty strings) are dropped.
    content: list[_ContentPart] = [res for res in parsed if res]

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role, content=content)]  # type: ignore

    texts = cast(list[str], content)
    storage = mm_parser.mm_placeholder_storage()
    text_prompt = (
        _get_full_multimodal_text_prompt(storage, texts, interleave_strings)
        if storage
        else "\n".join(texts)
    )
    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> _ContentPart | None:
    """Parses a single part of a conversation. If wrap_dicts is True,
    structured dictionary pieces for texts and images will be
    wrapped in dictionaries, i.e., {"type": "text", "text": ...} and
    {"type": "image"}, respectively. Otherwise texts will be returned as
    strings to be joined with multimodal placeholders.

    Media parts are registered with ``mm_parser`` in both modes.

    Returns:
        - ``str`` parts: the string itself.
        - Text-like parts: the text, or ``{"type": "text", "text": ...}``
          when ``wrap_dicts`` is True.
        - Media parts: ``{"type": <modality>}`` when ``wrap_dicts`` is True;
          otherwise the modality's entry in ``MODALITY_PLACEHOLDERS_MAP`` if
          ``interleave_strings`` is True, else ``None``.
        - ``None`` for a part whose type is in
          ``PART_TYPES_TO_SKIP_NONE_CONTENT`` and whose content is ``None``.

    Raises:
        NotImplementedError: If the part type is not recognized.
    """
    if isinstance(part, str):  # Handle plain text parts
        return part

    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)
    # Only the types listed in PART_TYPES_TO_SKIP_NONE_CONTENT are skipped
    # when their content is None. Media parts with None content are still
    # registered below (e.g. when only a uuid identifies the item).
    if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None:
        logger.warning(
            "Skipping multimodal part '%s' (type: '%s') "
            "with empty / unparsable content.",
            part,
            part_type,
        )
        return None

    if part_type in ("text", "input_text", "output_text", "refusal", "thinking"):
        str_content = cast(str, content)
        if wrap_dicts:
            return {"type": "text", "text": str_content}
        return str_content

    # Media items carry the user-provided uuid (as a string) or None.
    uuid = part.get("uuid", None)
    if uuid is not None:
        uuid = str(uuid)

    # `cast` returns its argument unchanged, so None content passes through.
    if part_type == "image_pil":
        mm_parser.parse_image_pil(cast(Image.Image | None, content), uuid)
        modality = "image"
    elif part_type in ("image_url", "input_image"):
        mm_parser.parse_image(cast(str, content), uuid)
        modality = "image"
    elif part_type == "image_embeds":
        mm_parser.parse_image_embeds(
            cast(str | dict[str, str] | None, content), uuid
        )
        modality = "image"
    elif part_type == "audio_embeds":
        mm_parser.parse_audio_embeds(
            cast(str | dict[str, str] | None, content), uuid
        )
        modality = "audio"
    elif part_type == "audio_url":
        mm_parser.parse_audio(cast(str, content), uuid)
        modality = "audio"
    elif part_type == "input_audio":
        mm_parser.parse_input_audio(cast(InputAudio, content), uuid)
        modality = "audio"
    elif part_type == "video_url":
        mm_parser.parse_video(cast(str, content), uuid)
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")

    if wrap_dicts:
        return {"type": modality}
    # Returning None makes the caller drop this part from the text pieces.
    return MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None


# These messages need no further Pydantic validation, so a plain `cast`
# (which returns its argument unchanged) gives them their TypedDict type.
_AssistantParser, _ToolParser = (
    partial(cast, ChatCompletionAssistantMessageParam),
    partial(cast, ChatCompletionToolMessageParam),
)


def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: ChatTemplateContentFormat,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Convert one request message into conversation message(s).

    String content is wrapped as a single text part and missing content is
    treated as no parts. Role-specific fields are then copied onto every
    resulting message:

    - assistant: ``tool_calls`` (when not None) and ``reasoning``, the latter
      duplicated as ``reasoning_content`` for compatibility;
    - tool: ``tool_call_id``;
    - developer: ``tools``;
    - any role: ``name`` when it is a string.
    """
    role = message["role"]
    raw_content = message.get("content")
    reasoning = message.get("reasoning") or message.get("reasoning_content")

    parts: Any
    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [ChatCompletionContentPartTextParam(type="text", text=raw_content)]
    else:
        parts = raw_content

    result = _parse_chat_message_content_parts(
        role,
        parts,
        mm_tracker,
        wrap_dicts=content_format == "openai",
        interleave_strings=interleave_strings,
    )

    has_str_name = "name" in message and isinstance(message["name"], str)

    for conv_msg in result:
        if role == "assistant":
            assistant_msg = _AssistantParser(message)
            # 'tool_calls' may be present but None when clients don't strictly
            # follow the OpenAI spec.
            if assistant_msg.get("tool_calls") is not None:
                conv_msg["tool_calls"] = list(assistant_msg["tool_calls"])
            # Include reasoning if present for interleaved thinking.
            if reasoning is not None:
                conv_msg["reasoning"] = cast(str, reasoning)
                conv_msg["reasoning_content"] = cast(str, reasoning)
        elif role == "tool":
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                conv_msg["tool_call_id"] = tool_msg["tool_call_id"]

        if has_str_name:
            conv_msg["name"] = message["name"]

        if role == "developer":
            conv_msg["tools"] = message.get("tools", None)

    return result


def _postprocess_messages(messages: list[ConversationMessage]) -> None:
    """Normalize assistant tool calls in ``messages`` for chat templates.

    Per the Transformers docs & maintainers, tool-use chat templates expect
    tool call arguments in assistant messages to be dicts, not the JSON
    strings used by the OpenAI format. For every assistant message whose
    ``tool_calls`` is a list:

    - an empty list is removed so templates take the normal assistant path;
    - otherwise the list is replaced by normalized copies of its entries
      (see ``_normalize_tool_call_arguments``).

    The message dicts in ``messages`` are updated in place, but the
    individual tool-call entries are copied rather than edited. Those entries
    may be the same objects as in the caller's request messages.

    Raises:
        json.JSONDecodeError: If a string ``arguments`` value is not valid
            JSON.
    """
    for message in messages:
        if message["role"] != "assistant" or "tool_calls" not in message:
            continue

        tool_calls = message.get("tool_calls")
        if not isinstance(tool_calls, list):
            continue

        if len(tool_calls) == 0:
            # Drop empty tool_calls to keep templates on the normal assistant path.
            message.pop("tool_calls", None)
            continue

        message["tool_calls"] = [
            _normalize_tool_call_arguments(item) for item in tool_calls
        ]


def _normalize_tool_call_arguments(item: Any) -> Any:
    """Return a copy of ``item`` whose ``function.arguments`` is a dict/list.

    JSON strings are decoded, falsy arguments (missing, None, "") become
    ``{}``, and values that are already a dict or list are kept.
    """
    function = dict(item["function"])
    if arguments := function.get("arguments"):
        if not isinstance(arguments, (dict, list)):
            function["arguments"] = json.loads(arguments)
    else:
        function["arguments"] = {}
    return {**item, "function": function}


def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Parse request messages into a conversation plus multimodal inputs.

    Returns:
        The post-processed conversation, then the multimodal data and the
        multimodal UUIDs collected by a ``MultiModalItemTracker``.
    """
    conversation: list[ConversationMessage] = []
    mm_tracker = MultiModalItemTracker(model_config)

    # Placeholder strings are interleaved with text only for string-format
    # templates on models whose multimodal config enables it.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
            )
        )

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()


async def parse_chat_messages_async(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Async variant of ``parse_chat_messages``.

    Uses an ``AsyncMultiModalItemTracker`` and awaits its collected
    multimodal data before returning.
    """
    conversation: list[ConversationMessage] = []
    mm_tracker = AsyncMultiModalItemTracker(model_config)

    # Placeholder strings are interleaved with text only for string-format
    # templates on models whose multimodal config enables it.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
            )
        )

    _postprocess_messages(conversation)

    return conversation, await mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()


def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
    """Count the tool calls made by assistant messages in ``conversation``.

    Assistant messages without ``tool_calls``, or with it set to None,
    contribute nothing.
    """
    return sum(
        len(list(tool_calls))
        for msg in conversation
        if msg["role"] == "assistant"
        and (tool_calls := msg.get("tool_calls")) is not None
    )


def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build a tool call id.

    ``"kimi_k2"`` ids have the form ``functions.<func_name>:<idx>``. Any
    other ``id_type`` produces a random ``chatcmpl-tool-...`` id.
    """
    if id_type != "kimi_k2":
        # by default return random
        return f"chatcmpl-tool-{random_uuid()}"
    return f"functions.{func_name}:{idx}"