# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Mapping
from typing import Any, overload

from typing_extensions import assert_never

from vllm.config import VllmConfig
from vllm.inputs import build_enc_dec_input
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.renderers import BaseRenderer, renderer_from_config
from vllm.renderers.inputs import (
    DecoderDictPrompt,
    DecoderOnlyDictPrompt,
    EncoderDecoderDictPrompt,
    EncoderDictPrompt,
    SingletonDictPrompt,
)
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
from vllm.tokenizers import TokenizerLike

from .engine import (
    DecoderEngineInput,
    DecoderOnlyEngineInput,
    EmbedsInput,
    EncoderDecoderInput,
    EncoderInput,
    EngineInput,
    MultiModalInput,
    SingletonInput,
    TokensInput,
    tokens_input,
)
from .llm import (
    EmbedsPrompt,
    MultiModalDataDict,
    MultiModalUUIDDict,
    PromptType,
    TextPrompt,
    TokensPrompt,
)

# Module-level logger named after this module.
# NOTE(review): not referenced by the code visible in this file; keep if
# other code relies on it, otherwise it may be removable.
logger = init_logger(__name__)


class InputPreprocessor:
    """Turn user-facing prompts into engine inputs.

    Tokenization, multi-modal processing and prompt-embedding handling are
    delegated to a :class:`BaseRenderer`. This class dispatches each prompt to
    the matching renderer call, attaches prompt-level metadata (``prompt``
    text, ``cache_salt``), and assembles encoder-decoder inputs when the model
    requires them.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        renderer: BaseRenderer | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        """
        Args:
            vllm_config: Engine configuration. Its ``model_config`` decides
                whether :meth:`preprocess` treats prompts as encoder-decoder
                or decoder-only.
            renderer: Renderer to delegate to. When ``None``, one is built
                with ``renderer_from_config(vllm_config)``.
            mm_registry: Multi-modal registry, stored on ``self.mm_registry``.
        """
        super().__init__()

        self.model_config = vllm_config.model_config
        # Compare against None explicitly so a caller-supplied renderer is
        # never replaced merely because it happens to evaluate as falsy.
        self.renderer = (
            renderer if renderer is not None else renderer_from_config(vllm_config)
        )
        self.mm_registry = mm_registry

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The renderer's tokenizer, or ``None`` if it does not have one."""
        return self.renderer.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer via ``renderer.get_tokenizer()``."""
        return self.renderer.get_tokenizer()

    def _get_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
        """
        Return the renderer's default completion tokenization parameters,
        overridden by ``tokenization_kwargs`` when any are given.
        """
        return self.renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )

    def _tokenize_prompt(
        self,
        prompt: str,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[int]:
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs.

        Args:
            prompt: Raw prompt text.
            tokenization_kwargs: Overrides for the renderer's default
                completion tokenization parameters.
        """
        tok_prompt = self.renderer._tokenize_singleton_prompt(
            TextPrompt(prompt=prompt),
            self._get_tok_params(tokenization_kwargs),
        )
        return tok_prompt["prompt_token_ids"]

    def _process_multimodal(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInput:
        """
        Apply the model's multi-modal processor to a multi-modal prompt,
        returning the corresponding token IDs and metadata.

        This is a thin delegate to ``renderer._process_multimodal``; ``prompt``
        may be either text or already-tokenized IDs.
        """
        return self.renderer._process_multimodal(
            prompt,
            mm_data,
            mm_uuids=mm_uuids,
            mm_processor_kwargs=mm_processor_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInput:
        """Build engine input for an embeddings prompt via the renderer."""
        return self.renderer._process_embeds(parsed_content)

    def _truncate_inputs(
        self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None
    ) -> list[int]:
        """
        Pass already-tokenized IDs through the renderer's singleton
        tokenization step, using the same parameters as text prompts.

        NOTE(review): truncation is presumably applied by the renderer
        according to ``tokenization_kwargs`` (e.g. a requested max length);
        confirm in ``BaseRenderer._tokenize_singleton_prompt``.
        """
        tok_prompt = self.renderer._tokenize_singleton_prompt(
            TokensPrompt(prompt_token_ids=inputs),
            self._get_tok_params(tokenization_kwargs),
        )
        return tok_prompt["prompt_token_ids"]

    def _process_tokens(
        self,
        parsed_content: TokensPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> TokensInput | MultiModalInput:
        """
        Build engine input for a pre-tokenized prompt.

        The token IDs first go through :meth:`_truncate_inputs`. If the prompt
        carries ``multi_modal_data``, they are then handed to the multi-modal
        processor together with any ``mm_processor_kwargs`` and
        ``multi_modal_uuids``. The optional ``prompt`` text and ``cache_salt``
        are copied onto the result when present and non-empty.
        """
        prompt_token_ids = self._truncate_inputs(
            parsed_content["prompt_token_ids"], tokenization_kwargs
        )

        inputs: TokensInput | MultiModalInput
        if multi_modal_data := parsed_content.get("multi_modal_data"):
            inputs = self._process_multimodal(
                prompt_token_ids,
                multi_modal_data,
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=parsed_content.get("multi_modal_uuids"),
            )
        else:
            inputs = tokens_input(prompt_token_ids)

        if prompt_text := parsed_content.get("prompt"):
            inputs["prompt"] = prompt_text
        if cache_salt := parsed_content.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    def _process_text(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> TokensInput | MultiModalInput:
        """
        Build engine input for a text prompt.

        With ``multi_modal_data`` the text goes to the multi-modal processor
        (including the prompt's ``multi_modal_uuids``, matching
        :meth:`_process_tokens`). Otherwise it is tokenized with
        :meth:`_tokenize_prompt`. The original text is always stored under
        ``"prompt"``, and ``cache_salt`` is copied when present and non-empty.
        """
        prompt_text = parsed_content["prompt"]

        inputs: TokensInput | MultiModalInput
        if multi_modal_data := parsed_content.get("multi_modal_data"):
            inputs = self._process_multimodal(
                prompt_text,
                multi_modal_data,
                parsed_content.get("mm_processor_kwargs") or {},
                tokenization_kwargs=tokenization_kwargs,
                # Previously omitted: text prompts lost their multi-modal
                # UUIDs while token prompts forwarded them.
                mm_uuids=parsed_content.get("multi_modal_uuids"),
            )
        else:
            prompt_token_ids = self._tokenize_prompt(
                prompt_text,
                tokenization_kwargs=tokenization_kwargs,
            )
            inputs = tokens_input(prompt_token_ids)

        inputs["prompt"] = prompt_text

        if cache_salt := parsed_content.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    @overload
    def _prompt_to_llm_inputs(
        self,
        prompt: EncoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> EncoderInput: ...

    @overload
    def _prompt_to_llm_inputs(  # type: ignore[misc]
        self,
        prompt: DecoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> DecoderEngineInput: ...

    @overload
    def _prompt_to_llm_inputs(  # type: ignore[misc]
        self,
        prompt: DecoderOnlyDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> DecoderOnlyEngineInput: ...

    def _prompt_to_llm_inputs(
        self,
        prompt: SingletonDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> SingletonInput:
        """
        Dispatch a singleton prompt to the handler matching its content.

        Keys are checked in order: ``prompt_embeds`` -> :meth:`_process_embeds`,
        ``prompt_token_ids`` -> :meth:`_process_tokens`, ``prompt`` ->
        :meth:`_process_text`. ``tokenization_kwargs`` is forwarded to both
        the token and the text handler.
        """
        if "prompt_embeds" in prompt:
            return self._process_embeds(prompt)  # type: ignore[arg-type]

        if "prompt_token_ids" in prompt:
            # Forward tokenization_kwargs so truncation / multi-modal
            # tokenization options apply to token prompts as well.
            return self._process_tokens(
                prompt,  # type: ignore[arg-type]
                tokenization_kwargs=tokenization_kwargs,
            )

        if "prompt" in prompt:
            return self._process_text(
                prompt,  # type: ignore[arg-type]
                tokenization_kwargs=tokenization_kwargs,
            )

        assert_never(prompt)  # type: ignore[arg-type]

    def _process_encoder_decoder_prompt(
        self,
        prompt: EncoderDecoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> EncoderDecoderInput:
        """
        Build encoder-decoder engine input.

        The encoder prompt and, if not ``None``, the decoder prompt are each
        processed with :meth:`_prompt_to_llm_inputs` and combined by
        ``build_enc_dec_input``. If the renderer's multi-modal processor is an
        ``EncDecMultiModalProcessor``, its ``skip_decoder_start_token`` flag is
        forwarded. Otherwise the decoder start token is not skipped.
        """
        encoder_prompt = prompt["encoder_prompt"]
        decoder_prompt = prompt["decoder_prompt"]

        mm_processor = self.renderer.mm_processor
        skip_decoder_start_token = False
        if mm_processor is not None:
            # Deferred import. NOTE(review): presumably to avoid an import
            # cycle or import cost at module load; confirm before hoisting.
            from vllm.multimodal.processing import EncDecMultiModalProcessor

            if isinstance(mm_processor, EncDecMultiModalProcessor):
                skip_decoder_start_token = mm_processor.skip_decoder_start_token

        return build_enc_dec_input(
            encoder_input=self._prompt_to_llm_inputs(
                encoder_prompt,
                tokenization_kwargs=tokenization_kwargs,
            ),
            decoder_input=(
                None
                if decoder_prompt is None
                else self._prompt_to_llm_inputs(
                    decoder_prompt,
                    tokenization_kwargs=tokenization_kwargs,
                )
            ),
            decoder_start_token_id=self.renderer.get_dec_start_token_id(),
            skip_decoder_start_token=skip_decoder_start_token,
        )

    def _process_decoder_only_prompt(
        self,
        prompt: DecoderOnlyDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> DecoderOnlyEngineInput:
        """Build engine input for a decoder-only model's prompt."""
        return self._prompt_to_llm_inputs(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
        )

    def preprocess(
        self,
        prompt: PromptType,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> EngineInput:
        """
        Preprocess the input prompt.

        For encoder-decoder models (``model_config.is_encoder_decoder``), the
        prompt is split into encoder and decoder parts by
        ``parse_enc_dec_prompt``. All other models go through
        ``parse_dec_only_prompt``.

        Args:
            prompt: The user-facing prompt.
            tokenization_kwargs: Overrides for the renderer's default
                tokenization parameters, applied to every sub-prompt.
        """
        if self.model_config.is_encoder_decoder:
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder.
            return self._process_encoder_decoder_prompt(
                parse_enc_dec_prompt(prompt),
                tokenization_kwargs=tokenization_kwargs,
            )

        return self._process_decoder_only_prompt(
            parse_dec_only_prompt(prompt),
            tokenization_kwargs=tokenization_kwargs,
        )