mistral.py 3.91 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from concurrent.futures import ThreadPoolExecutor

5
from vllm.config import VllmConfig
6
7
8
9
10
11
12
13
14
15
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async

16
from .base import BaseRenderer
17
18
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
19
from .params import ChatParams
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50

# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)


def safe_apply_chat_template(
    tokenizer: MistralTokenizer,
    messages: list[ChatCompletionMessageParam],
    **kwargs,
) -> str | list[int]:
    """Apply the tokenizer's chat template, reporting any failure as ValueError.

    Args:
        tokenizer: Tokenizer whose ``apply_chat_template`` is invoked.
        messages: Chat messages handed to the tokenizer unchanged.
        **kwargs: Extra keyword arguments forwarded to
            ``tokenizer.apply_chat_template``.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns.

    Raises:
        ValueError: If applying the template raises any ``Exception``. The
            original exception is chained as ``__cause__`` and its string
            form becomes the error message.
    """
    from mistral_common.exceptions import MistralCommonException

    try:
        rendered = tokenizer.apply_chat_template(messages, **kwargs)
    except (AssertionError, MistralCommonException) as invalid_input:
        # mistral-common stops processing with assert statements when the
        # input does not comply with the expected format. Those assertion
        # errors are converted to ValueErrors so that the input
        # preprocessing step can catch them properly.
        raise ValueError(str(invalid_input)) from invalid_input
    except Exception as unexpected:
        # Exceptions from the external library can still occur despite the
        # framework's own exception management; log them (with traceback)
        # for further investigation before converting them as well.
        logger.exception(
            "An error occurred in `mistral_common` while applying chat template"
        )
        raise ValueError(str(unexpected)) from unexpected
    else:
        return rendered


51
52
53
54
55
56
57
class MistralRenderer(BaseRenderer[MistralTokenizer]):
    """Renderer that builds prompts from chat messages with a MistralTokenizer.

    Offers a synchronous and an asynchronous entry point. Both parse the chat
    messages, apply the tokenizer's chat template through
    ``safe_apply_chat_template`` and attach any multi-modal inputs that the
    parsing step produced.
    """

    def __init__(
        self,
        config: VllmConfig,
        tokenizer: MistralTokenizer | None,
    ) -> None:
        """Initialize the renderer and its chat-template executor.

        Args:
            config: Configuration forwarded to ``BaseRenderer.__init__``.
            tokenizer: Tokenizer forwarded to ``BaseRenderer.__init__``;
                may be ``None``.
        """
        super().__init__(config, tokenizer)

        # Dedicated pool with a single worker: at most one chat-template
        # application submitted through the async wrapper runs at a time.
        self._apply_chat_template_executor = ThreadPoolExecutor(max_workers=1)
        # Async counterpart of `safe_apply_chat_template`, bound to the
        # executor above.
        self._apply_chat_template_async = make_async(
            safe_apply_chat_template, executor=self._apply_chat_template_executor
        )

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Render chat messages into a conversation and a prompt.

        Args:
            messages: Chat messages to render.
            params: Chat parameters; supplies ``media_io_kwargs``,
                ``mm_processor_kwargs`` and the chat-template keyword
                arguments.

        Returns:
            A ``(conversation, prompt)`` tuple. ``prompt`` carries
            ``"multi_modal_data"`` / ``"multi_modal_uuids"`` entries only
            when the parsing step returned non-``None`` values for them.

        Raises:
            ValueError: If applying the chat template fails (raised by
                ``safe_apply_chat_template``).
        """
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.model_config,
            content_format="string",
            media_io_kwargs=params.media_io_kwargs,
            mm_processor_kwargs=params.mm_processor_kwargs,
        )

        # NOTE: the template is applied to the raw `messages`, not to the
        # parsed `conversation`.
        prompt_raw = safe_apply_chat_template(
            tokenizer,
            messages,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = parse_dec_only_prompt(prompt_raw)
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Asynchronous variant of ``render_messages``.

        Performs the same steps, but awaits ``parse_chat_messages_async`` and
        the executor-backed chat-template wrapper created in ``__init__``.

        Args:
            messages: Chat messages to render.
            params: Chat parameters; supplies ``media_io_kwargs``,
                ``mm_processor_kwargs`` and the chat-template keyword
                arguments.

        Returns:
            A ``(conversation, prompt)`` tuple. ``prompt`` carries
            ``"multi_modal_data"`` / ``"multi_modal_uuids"`` entries only
            when the parsing step returned non-``None`` values for them.

        Raises:
            ValueError: If applying the chat template fails (raised by
                ``safe_apply_chat_template``).
        """
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.model_config,
            content_format="string",
            media_io_kwargs=params.media_io_kwargs,
            mm_processor_kwargs=params.mm_processor_kwargs,
        )

        # NOTE: the template is applied to the raw `messages`, not to the
        # parsed `conversation`.
        prompt_raw = await self._apply_chat_template_async(
            tokenizer,
            messages,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = parse_dec_only_prompt(prompt_raw)
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt