mistral.py 4.55 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from concurrent.futures import ThreadPoolExecutor
from typing import Any

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
13
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
14
15
16
17
18
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async

19
from .params import ChatParams
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from .protocol import RendererLike

logger = init_logger(__name__)


def safe_apply_chat_template(
    tokenizer: MistralTokenizer,
    messages: list[ChatCompletionMessageParam],
    **kwargs,
) -> str | list[int]:
    """Apply the tokenizer's chat template, surfacing failures as `ValueError`.

    Args:
        tokenizer: Tokenizer whose `apply_chat_template` method is invoked.
        messages: Chat messages handed to the template unchanged.
        **kwargs: Extra keyword arguments forwarded to `apply_chat_template`.

    Returns:
        Whatever `tokenizer.apply_chat_template` returns.

    Raises:
        ValueError: For any exception raised while applying the template.
            The original exception is chained as `__cause__`.
    """
    from mistral_common.exceptions import MistralCommonException

    # mistral-common stops processing of non-compliant input with assert
    # statements (or its own exception type). Those are expected input
    # errors, so they are converted without logging; the ValueError lets
    # the input preprocessing step catch them.
    expected_errors = (AssertionError, MistralCommonException)

    try:
        rendered = tokenizer.apply_chat_template(messages, **kwargs)
    except expected_errors as exc:
        raise ValueError(str(exc)) from exc
    except Exception as exc:
        # Anything else escaped the library's own exception handling: log it
        # with the traceback for investigation, then report it the same way.
        logger.exception(
            "An error occurred in `mistral_common` while applying chat template"
        )
        raise ValueError(str(exc)) from exc

    return rendered


class MistralRenderer(RendererLike):
    """Renderer that builds prompts from chat messages with a `MistralTokenizer`.

    The chat template is applied through `safe_apply_chat_template`, so every
    template failure reaches the caller as a `ValueError`.
    """

    @classmethod
    def from_config(
        cls,
        config: ModelConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "RendererLike":
        """Create a renderer from a model config and tokenizer keyword arguments."""
        return cls(config, tokenizer_kwargs)

    def __init__(
        self,
        config: ModelConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> None:
        """Store the config and load the tokenizer unless initialization is skipped.

        Args:
            config: Model configuration; `config.skip_tokenizer_init` decides
                whether a tokenizer is loaded at all.
            tokenizer_kwargs: Keyword arguments forwarded to
                `cached_get_tokenizer` alongside `tokenizer_cls=MistralTokenizer`.
        """
        super().__init__()

        self.config = config

        if config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = cached_get_tokenizer(
                tokenizer_cls=MistralTokenizer,
                **tokenizer_kwargs,
            )

        self._tokenizer = tokenizer

        # The async template path gets its own single-worker executor.
        # NOTE(review): presumably `make_async` submits each call to this
        # executor, which would run template calls one at a time - confirm.
        self._apply_chat_template_executor = ThreadPoolExecutor(max_workers=1)
        self._apply_chat_template_async = make_async(
            safe_apply_chat_template, executor=self._apply_chat_template_executor
        )

    @property
    def tokenizer(self) -> MistralTokenizer | None:
        """The loaded tokenizer, or `None` when tokenizer init was skipped."""
        return self._tokenizer

    def get_tokenizer(self) -> MistralTokenizer:
        """Return the tokenizer, failing loudly when none was loaded.

        Raises:
            ValueError: If the renderer holds no tokenizer.
        """
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")

        return tokenizer

    def _finish_prompt(
        self,
        prompt_raw: str | list[int],
        mm_data: Any,
        mm_uuids: Any,
    ) -> TextPrompt | TokensPrompt | EmbedsPrompt:
        """Turn the raw template output into a prompt and attach multi-modal fields.

        Shared by the sync and async render paths so both build the prompt
        identically. `mm_data` and `mm_uuids` are only stored when they are
        not `None`.
        """
        prompt = self.render_completion(prompt_raw)
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return prompt

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
        """Render chat messages into a parsed conversation and a prompt.

        Args:
            messages: Chat messages to parse and to feed to the chat template.
            params: Source of the keyword arguments for the chat template
                (`params.get_apply_chat_template_kwargs()`).

        Returns:
            A `(conversation, prompt)` tuple: the conversation produced by
            `parse_chat_messages` and the prompt built from the template output.

        Raises:
            ValueError: If no tokenizer is available or the template fails.
        """
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.config,
            content_format="string",
        )

        # The template receives the original messages, not the parsed
        # conversation.
        prompt_raw = safe_apply_chat_template(
            tokenizer,
            messages,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = self._finish_prompt(prompt_raw, mm_data, mm_uuids)

        return conversation, prompt

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
        """Async counterpart of `render_messages`.

        Messages are parsed with `parse_chat_messages_async` and the chat
        template is applied through the `make_async` wrapper created in
        `__init__`; arguments, return value and errors otherwise match
        `render_messages`.
        """
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.config,
            content_format="string",
        )

        # The template receives the original messages, not the parsed
        # conversation.
        prompt_raw = await self._apply_chat_template_async(
            tokenizer,
            messages,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = self._finish_prompt(prompt_raw, mm_data, mm_uuids)

        return conversation, prompt