io_processor.py 7.95 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
from collections.abc import Sequence
5
6
7
8
9
10
11
12
13
14
15
from typing import Any, Final

from vllm import PoolingRequestOutput, PromptType
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateConfig,
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest
16
17
18
19
20
from vllm.entrypoints.pooling.typing import (
    PoolingChatLikeRequest,
    PoolingCompletionLikeRequest,
    PoolingServeContext,
)
21
from vllm.inputs import EngineInput, SingletonPrompt
22
23
24
25
26
27
28
from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tool_parsers import ToolParser
from vllm.utils.mistral import is_mistral_tokenizer


class PoolingIOProcessor:
    """Turn pooling requests/prompts into engine inputs via a renderer.

    The class offers matching sync/async hooks for the online (serving) and
    offline paths. The async variants defined here simply call their sync
    counterparts, and the post-processing hooks are pass-throughs.
    """

    # Only annotated here, never assigned in this class; presumably each
    # concrete subclass sets it. Used in the "invalid request type" message.
    name: str

    # Optional request attributes forwarded to the renderer as prompt extras
    # when they are present and not None.
    _PROMPT_EXTRA_FIELDS = ("mm_processor_kwargs", "cache_salt")

    def __init__(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ):
        """Store the model config, the renderer and the chat-template settings.

        Args:
            model_config: Kept as ``self.model_config``; passed to prompt
                parsing and to the request's ``build_tok_params``.
            renderer: Kept as ``self.renderer``; performs the rendering.
            chat_template_config: Source of the default chat template, its
                content format and the ``trust_request_chat_template`` flag.
        """
        self.model_config = model_config
        self.renderer = renderer

        self.chat_template = chat_template_config.chat_template
        self.chat_template_content_format: Final = (
            chat_template_config.chat_template_content_format
        )
        self.trust_request_chat_template = (
            chat_template_config.trust_request_chat_template
        )

    def create_pooling_params(self, request):
        """Return the result of ``request.to_pooling_params()``."""
        return request.to_pooling_params()

    #######################################
    # online APIs

    def pre_process_online(self, ctx: PoolingServeContext):
        """Build engine inputs for ``ctx.request`` and store them on ``ctx``.

        Chat-like requests are first checked by ``_validate_chat_template``
        and then rendered through ``_preprocess_chat_online``; the rendered
        conversation is discarded and only the engine inputs are kept.
        Completion-like requests are rendered from ``request.input`` through
        ``_preprocess_completion_online`` (no prompt embeds are passed).

        Raises:
            ValueError: If the request is neither chat-like nor
                completion-like, or if the chat template validation fails.
        """
        request = ctx.request

        # Test the local ``request`` (not ``ctx.request``) in both branches so
        # the check and the attribute accesses refer to the same name.
        if isinstance(request, PoolingChatLikeRequest):
            self._validate_chat_template(
                request_chat_template=request.chat_template,
                chat_template_kwargs=request.chat_template_kwargs,
                trust_request_chat_template=self.trust_request_chat_template,
            )
            _, engine_inputs = self._preprocess_chat_online(
                request,
                request.messages,
                default_template=self.chat_template,
                default_template_content_format=self.chat_template_content_format,
                default_template_kwargs=None,
            )
        elif isinstance(request, PoolingCompletionLikeRequest):
            engine_inputs = self._preprocess_completion_online(
                request,
                prompt_input=request.input,
                prompt_embeds=None,
            )
        else:
            # ``name`` is not assigned by this class; fall back to the class
            # name so a missing attribute cannot mask the real error.
            processor_name = getattr(self, "name", type(self).__name__)
            raise ValueError(f"Invalid {processor_name} request type")

        ctx.engine_inputs = engine_inputs

    async def pre_process_online_async(self, ctx: PoolingServeContext):
        """Async entry point; runs ``pre_process_online`` directly."""
        self.pre_process_online(ctx)

    def post_process_online(
        self,
        ctx: PoolingServeContext,
    ):
        """Hook called after online inference; does nothing in this class."""
        pass

    async def post_process_online_async(
        self,
        ctx: PoolingServeContext,
    ):
        """Async entry point; runs ``post_process_online`` directly."""
        self.post_process_online(ctx)

    #######################################
    # offline APIs

    def pre_process_offline(
        self,
        prompts: PromptType | Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[EngineInput]:
        """Render offline prompts via ``_preprocess_completion_offline``.

        Args:
            prompts: One prompt or a sequence of prompts.
            tokenization_kwargs: Optional overrides applied to the renderer's
                default completion tokenization parameters.

        Returns:
            The engine inputs produced by the renderer.
        """
        return self._preprocess_completion_offline(
            prompts=prompts, tokenization_kwargs=tokenization_kwargs
        )

    async def pre_process_offline_async(self, *args, **kwargs):
        """Async entry point; forwards all arguments to ``pre_process_offline``."""
        return self.pre_process_offline(*args, **kwargs)

    def post_process_offline(
        self,
        outputs: list[PoolingRequestOutput],
    ) -> list[PoolingRequestOutput]:
        """Hook called after offline inference; returns ``outputs`` unchanged."""
        return outputs

    async def post_process_offline_async(
        self,
        outputs: list[PoolingRequestOutput],
    ) -> list[PoolingRequestOutput]:
        """Async entry point; returns ``post_process_offline(outputs)``."""
        return self.post_process_offline(outputs)

    #######################################
    # helpers

    @staticmethod
    def _collect_prompt_extras(request: Any) -> dict[str, Any]:
        """Collect the optional per-request extras handed to the renderer.

        Returns a dict holding each attribute named in
        ``_PROMPT_EXTRA_FIELDS`` that exists on ``request`` and is not None.
        """
        extras: dict[str, Any] = {}
        for field in PoolingIOProcessor._PROMPT_EXTRA_FIELDS:
            value = getattr(request, field, None)
            if value is not None:
                extras[field] = value
        return extras

    def _preprocess_completion_online(
        self,
        request: RendererRequest,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
    ) -> list[EngineInput]:
        """Render a completion-style online request into engine inputs.

        Args:
            request: Supplies ``build_tok_params`` and the optional prompt
                extras (``mm_processor_kwargs``, ``cache_salt``).
            prompt_input: Text/token prompt(s), or None.
            prompt_embeds: Serialized prompt embedding(s), or None.

        Returns:
            Whatever ``renderer.render_cmpl`` returns for the parsed prompts.
        """
        renderer = self.renderer
        model_config = self.model_config

        prompts: list[SingletonPrompt | bytes] = []
        if prompt_embeds is not None:  # embeds take higher priority
            prompts.extend(prompt_to_seq(prompt_embeds))
        if prompt_input is not None:
            prompts.extend(prompt_to_seq(prompt_input))

        # ``bytes`` entries are passed through untouched; everything else is
        # parsed against the model config.
        parsed_prompts = [
            (
                prompt
                if isinstance(prompt, bytes)
                else parse_model_prompt(model_config, prompt)
            )
            for prompt in prompts
        ]
        tok_params = request.build_tok_params(model_config)

        return renderer.render_cmpl(
            parsed_prompts,
            tok_params,
            prompt_extras=self._collect_prompt_extras(request),
        )

    def _preprocess_chat_online(
        self,
        request: RendererChatRequest,
        messages: list[ChatCompletionMessageParam],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: type[ToolParser] | None = None,
    ) -> tuple[list[ConversationMessage], list[EngineInput]]:
        """Render a single chat conversation into engine inputs.

        Args:
            request: Supplies ``build_tok_params``, ``build_chat_params`` and
                the optional prompt extras.
            messages: The conversation to render.
            default_template: Default chat template passed to
                ``build_chat_params``.
            default_template_content_format: Default content format passed to
                ``build_chat_params``.
            default_template_kwargs: Default template kwargs, merged with
                ``tools`` and ``tokenize`` via ``merge_kwargs``.
            tool_dicts: Tool definitions placed under the ``tools`` kwarg.
            tool_parser: Accepted for interface compatibility; not used by
                this implementation.

        Returns:
            ``(conversation, [engine_input])`` for the one rendered
            conversation.
        """
        renderer = self.renderer

        default_template_kwargs = merge_kwargs(
            default_template_kwargs,
            dict(
                tools=tool_dicts,
                tokenize=is_mistral_tokenizer(renderer.tokenizer),
            ),
        )

        mm_config = self.model_config.multimodal_config

        tok_params = request.build_tok_params(self.model_config)
        chat_params = request.build_chat_params(
            default_template, default_template_content_format
        ).with_defaults(
            default_template_kwargs,
            default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
        )

        # Exactly one conversation goes in, so exactly one conversation and
        # one engine input are unpacked from the result.
        (conversation,), (engine_input,) = renderer.render_chat(
            [messages],
            chat_params,
            tok_params,
            prompt_extras=self._collect_prompt_extras(request),
        )

        return conversation, [engine_input]

    def _preprocess_completion_offline(
        self,
        prompts: PromptType | Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[EngineInput]:
        """Render offline prompt(s) into engine inputs.

        Args:
            prompts: One prompt or a sequence of prompts.
            tokenization_kwargs: Optional keyword overrides applied on top of
                ``renderer.default_cmpl_tok_params``; None means no overrides.

        Returns:
            Whatever ``renderer.render_cmpl`` returns for the parsed prompts.
        """
        renderer = self.renderer
        model_config = self.model_config

        # Keep the parameter intact; work on a separate normalized sequence.
        prompt_seq = prompt_to_seq(prompts)

        parsed_prompts = [
            (
                prompt
                if isinstance(prompt, bytes)
                else parse_model_prompt(model_config, prompt)
            )
            for prompt in prompt_seq
        ]
        tok_params = renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )

        return renderer.render_cmpl(
            parsed_prompts,
            tok_params,
        )

    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ):
        """Reject request-supplied chat templates unless they are trusted.

        Args:
            request_chat_template: Template given directly on the request.
            chat_template_kwargs: Request template kwargs; a non-None
                ``"chat_template"`` entry also counts as a supplied template.
            trust_request_chat_template: When true, nothing is rejected.

        Returns:
            None when the request is acceptable.

        Raises:
            ValueError: If a template was supplied and is not trusted.
        """
        if trust_request_chat_template:
            return None

        template_in_kwargs = bool(
            chat_template_kwargs
            and chat_template_kwargs.get("chat_template") is not None
        )
        if request_chat_template is None and not template_in_kwargs:
            return None

        raise ValueError(
            "Chat template is passed with request, but "
            "--trust-request-chat-template is not set. "
            "Refused request with untrusted chat template."
        )