io_processor.py 8.29 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
from collections.abc import Sequence
5
6
from typing import Any, Final

7
8
from vllm import PoolingParams, PoolingRequestOutput, PromptType
from vllm.config import VllmConfig
9
10
11
12
13
14
15
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateConfig,
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest
16
17
18
19
20
21
22
23
from vllm.inputs import EngineInput, SingletonPrompt
from vllm.renderers import BaseRenderer, TokenizeParams, merge_kwargs
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tool_parsers import ToolParser
from vllm.utils.mistral import is_mistral_tokenizer

from ..scoring.typing import ScoringData
from ..typing import (
24
25
    OfflineInputsContext,
    OfflineOutputsContext,
26
27
28
29
    PoolingChatLikeRequest,
    PoolingCompletionLikeRequest,
    PoolingServeContext,
)
30
31
32


class PoolingIOProcessor:
33
34
35
36
37
38
    """Processor for handling preprocessing & postprocessing ops for pooling requests.

    This class manages both online (serving) and offline (batch) processing of pooling
    requests, handling chat and completion formats.
    """

39
40
    name: str

41
42
    def __init__(
        self,
43
        vllm_config: VllmConfig,
44
45
46
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ):
47
48
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
49
50
51
52
53
54
55
56
57
58
        self.renderer = renderer

        self.chat_template = chat_template_config.chat_template
        self.chat_template_content_format: Final = (
            chat_template_config.chat_template_content_format
        )
        self.trust_request_chat_template = (
            chat_template_config.trust_request_chat_template
        )

59
60
    #######################################
    # online APIs
61

62
63
64
    def create_pooling_params(self, request):
        """Return the pooling parameters produced by ``request.to_pooling_params()``."""
        pooling_params = request.to_pooling_params()
        return pooling_params

65
66
67
    def pre_process_online(self, ctx: PoolingServeContext):
        """Render the request held by ``ctx`` and store the result on the context.

        Chat-like requests first have their chat template checked by
        ``_validate_chat_template`` and are then rendered through
        ``_preprocess_chat_online``. Completion-like requests are rendered
        through ``_preprocess_cmpl_online`` using ``request.input``. In both
        cases the rendered inputs are written to ``ctx.engine_inputs``.

        Raises:
            ValueError: If the request is neither chat-like nor
                completion-like, or if chat-template validation rejects it.
        """
        request = ctx.request

        if isinstance(request, PoolingChatLikeRequest):
            self._validate_chat_template(
                request_chat_template=request.chat_template,
                chat_template_kwargs=request.chat_template_kwargs,
                trust_request_chat_template=self.trust_request_chat_template,
            )
            # The rendered conversation is not needed here; only the inputs are.
            _conversation, rendered = self._preprocess_chat_online(
                request,
                request.messages,
                default_template=self.chat_template,
                default_template_content_format=self.chat_template_content_format,
                default_template_kwargs=None,
            )
            ctx.engine_inputs = rendered
            return

        if isinstance(request, PoolingCompletionLikeRequest):
            ctx.engine_inputs = self._preprocess_cmpl_online(
                request,
                prompt_input=request.input,
                prompt_embeds=None,
            )
            return

        raise ValueError(f"Invalid {self.name} request type")
91
92
93
94
95
96
97
98
99
100

    def post_process_online(
        self,
        ctx: PoolingServeContext,
    ):
        pass

    #######################################
    # offline APIs

101
    def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
        """Render the offline prompts held by ``ctx`` into engine inputs.

        The prompts are normalised with ``prompt_to_seq`` and rendered via
        ``_preprocess_cmpl_offline``. Tokenization parameters start from the
        renderer's ``default_cmpl_tok_params`` and are overridden by
        ``ctx.tokenization_kwargs`` when that is set.

        Raises:
            ValueError: If ``ctx.prompts`` is a ``ScoringData`` instance or a
                dict containing a ``"data"`` key; neither is handled here.
        """
        raw_prompts = ctx.prompts

        # Explicit raise instead of ``assert``: assertions are stripped under
        # ``python -O``, which would let unsupported prompts reach the renderer.
        if isinstance(raw_prompts, ScoringData) or (
            isinstance(raw_prompts, dict) and "data" in raw_prompts
        ):
            raise ValueError(
                f"{type(self).__name__}.pre_process_offline does not accept "
                "ScoringData or dict prompts containing a 'data' key."
            )

        prompts_seq = prompt_to_seq(raw_prompts)
        tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
            **(ctx.tokenization_kwargs or {})
        )
        return self._preprocess_cmpl_offline(prompts=prompts_seq, tok_params=tok_params)
111

112
113
    def post_process_offline(
        self,
114
        ctx: OfflineOutputsContext,
115
    ) -> list[PoolingRequestOutput]:
116
        return ctx.outputs
117

118
119
    #######################################
    # helpers
120

121
    def _preprocess_cmpl_online(
        self,
        request: RendererRequest,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
    ) -> list[EngineInput]:
        """Render a completion-style request into engine inputs.

        Args:
            request: Request supplying the tokenization parameters (via
                ``build_tok_params``) and the optional ``mm_processor_kwargs``
                and ``cache_salt`` attributes.
            prompt_input: Text or token prompt(s), or ``None``.
            prompt_embeds: Serialized embedding prompt(s), or ``None``.

        Returns:
            The result of ``renderer.render_cmpl`` for the collected prompts.
        """
        model_config = self.model_config

        # Embeddings are queued ahead of the text/token prompts.
        raw_prompts: list[SingletonPrompt | bytes] = []
        for source in (prompt_embeds, prompt_input):
            if source is not None:
                raw_prompts.extend(prompt_to_seq(source))

        # ``bytes`` entries are forwarded untouched; everything else is parsed.
        parsed_prompts = []
        for raw in raw_prompts:
            if isinstance(raw, bytes):
                parsed_prompts.append(raw)
            else:
                parsed_prompts.append(parse_model_prompt(model_config, raw))

        tok_params = request.build_tok_params(model_config)

        # Forward only the optional request fields that are present and set.
        extras: dict[str, Any] = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                extras[key] = value

        return self.renderer.render_cmpl(
            parsed_prompts,
            tok_params,
            prompt_extras=extras,
        )

    def _preprocess_chat_online(
        self,
        request: RendererChatRequest,
        messages: list[ChatCompletionMessageParam],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: type[ToolParser] | None = None,
    ) -> tuple[list[ConversationMessage], list[EngineInput]]:
        """Render a single chat conversation into engine inputs.

        Args:
            request: Request supplying tokenization and chat parameters (via
                ``build_tok_params`` / ``build_chat_params``) and the optional
                ``mm_processor_kwargs`` and ``cache_salt`` attributes.
            messages: The conversation to render.
            default_template: Chat template passed to ``build_chat_params``.
            default_template_content_format: Content format passed to
                ``build_chat_params``.
            default_template_kwargs: Template kwargs merged with ``tools`` and
                ``tokenize`` before being applied as defaults.
            tool_dicts: Value forwarded as the ``tools`` template kwarg.
            tool_parser: Accepted but not used by this method.

        Returns:
            ``(conversation, [engine_input])`` for the rendered conversation.

        Raises:
            ValueError: If the renderer does not return exactly one
                conversation and one engine input (from tuple unpacking).
        """
        renderer = self.renderer

        template_defaults = merge_kwargs(
            default_template_kwargs,
            {
                "tools": tool_dicts,
                "tokenize": is_mistral_tokenizer(renderer.tokenizer),
            },
        )

        mm_config = self.model_config.multimodal_config
        media_io_defaults = mm_config.media_io_kwargs if mm_config else None

        tok_params = request.build_tok_params(self.model_config)
        base_chat_params = request.build_chat_params(
            default_template, default_template_content_format
        )
        chat_params = base_chat_params.with_defaults(
            template_defaults,
            default_media_io_kwargs=media_io_defaults,
        )

        # Forward only the optional request fields that are present and set.
        extras: dict[str, Any] = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                extras[key] = value

        conversations, engine_inputs = renderer.render_chat(
            [messages],
            chat_params,
            tok_params,
            prompt_extras=extras,
        )
        # One conversation went in, so exactly one of each must come out.
        (conversation,) = conversations
        (engine_input,) = engine_inputs

        return conversation, [engine_input]
198

199
    def _preprocess_cmpl_offline(
        self,
        prompts: PromptType | Sequence[PromptType],
        tok_params: TokenizeParams,
        prompt_extras: dict[str, Any] | None = None,
    ) -> Sequence[EngineInput]:
        """Render offline completion prompts into engine inputs.

        Args:
            prompts: One prompt or a sequence of prompts; normalised with
                ``prompt_to_seq``.
            tok_params: Tokenization parameters handed to the renderer.
            prompt_extras: Optional extras forwarded to ``render_cmpl`` as-is.

        Returns:
            The result of ``renderer.render_cmpl`` for the parsed prompts.
        """
        model_config = self.model_config

        # ``bytes`` entries are forwarded untouched; everything else is parsed.
        parsed_prompts = []
        for item in prompt_to_seq(prompts):
            if isinstance(item, bytes):
                parsed_prompts.append(item)
            else:
                parsed_prompts.append(parse_model_prompt(model_config, item))

        return self.renderer.render_cmpl(
            parsed_prompts,
            tok_params,
            prompt_extras=prompt_extras,
        )

    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ):
        if not trust_request_chat_template and (
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
            raise ValueError(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
                "Refused request with untrusted chat template."
            )
        return None
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253

    def _params_to_seq(
        self,
        params: PoolingParams | Sequence[PoolingParams],
        num_requests: int,
    ) -> Sequence[PoolingParams]:
        if isinstance(params, Sequence):
            if len(params) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and params ({len(params)}) must be the same."
                )

            return params

        return [params] * num_requests