# preprocess.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Mapping
5
from typing import Any, overload
6
7
8

from typing_extensions import assert_never

9
from vllm.config import VllmConfig
10
from vllm.inputs.data import build_enc_dec_inputs
11
from vllm.logger import init_logger
12
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
13
14
15
16
17
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalInputs,
    MultiModalUUIDDict,
)
18
from vllm.renderers import BaseRenderer, renderer_from_config
19
20
21
22
23
24
25
26
from vllm.renderers.inputs import (
    DecoderDictPrompt,
    DecoderOnlyDictPrompt,
    EncoderDecoderDictPrompt,
    EncoderDictPrompt,
    SingletonDictPrompt,
)
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
27
from vllm.tokenizers import TokenizerLike
28

29
from .data import (
30
    DecoderInputs,
31
32
33
34
    DecoderOnlyInputs,
    EmbedsInputs,
    EmbedsPrompt,
    EncoderDecoderInputs,
35
    EncoderInputs,
36
37
38
39
40
41
42
43
    ProcessorInputs,
    PromptType,
    SingletonInputs,
    TextPrompt,
    TokenInputs,
    TokensPrompt,
    token_inputs,
)
# Module-level logger named after this module.
logger = init_logger(__name__)


class InputPreprocessor:
    """
    Turn user-facing prompts into
    [`ProcessorInputs`][vllm.inputs.data.ProcessorInputs].

    Tokenization, multi-modal processing and prompt-embeds handling are all
    delegated to a renderer. This class parses the prompt shape and routes
    each singleton prompt to the matching renderer hook. It chooses between
    the encoder-decoder and decoder-only paths based on
    ``model_config.is_encoder_decoder``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        renderer: BaseRenderer | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        """
        Arguments:

        * vllm_config: engine configuration; its ``model_config`` is kept
          and later consulted to pick the encoder-decoder or decoder-only
          processing path.
        * renderer: renderer to delegate to. When omitted (or falsy), one is
          built from ``vllm_config`` via ``renderer_from_config``.
        * mm_registry: multi-modal registry, stored on the instance.
        """
        super().__init__()

        self.model_config = vllm_config.model_config
        self.renderer = renderer or renderer_from_config(vllm_config)
        self.mm_registry = mm_registry

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The renderer's tokenizer (may be ``None``)."""
        return self.renderer.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """
        Return the renderer's tokenizer via ``renderer.get_tokenizer()``.

        Unlike the ``tokenizer`` property, the annotated return type is not
        optional. Presumably the renderer raises when no tokenizer is
        available (NOTE(review): confirm in ``BaseRenderer``).
        """
        return self.renderer.get_tokenizer()

    def _tokenize_prompt(
        self,
        prompt: str,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[int]:
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs.

        Arguments:

        * prompt: the text to tokenize.
        * tokenization_kwargs: optional overrides applied on top of the
          renderer's default completion tokenization params.

        Returns:

        * The ``prompt_token_ids`` produced by the renderer.
        """
        renderer = self.renderer

        # Start from the renderer's completion defaults and layer any
        # caller-supplied kwargs on top.
        tok_params = renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )

        tok_prompt = renderer._tokenize_singleton_prompt(
            TextPrompt(prompt=prompt),
            tok_params,
        )

        return tok_prompt["prompt_token_ids"]

    def _process_multimodal(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """
        Apply the model's multi-modal processor to a multi-modal prompt,
        returning the corresponding token IDs and metadata.

        All arguments are forwarded unchanged to
        ``renderer._process_multimodal``.

        Arguments:

        * prompt: prompt text, or already-tokenized prompt IDs.
        * mm_data: the multi-modal data attached to the prompt.
        * mm_processor_kwargs: optional extra kwargs for the processor.
        * tokenization_kwargs: optional tokenization overrides.
        * mm_uuids: optional caller-supplied identifiers for ``mm_data``.
        """
        return self.renderer._process_multimodal(
            prompt,
            mm_data,
            mm_uuids=mm_uuids,
            mm_processor_kwargs=mm_processor_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        """Delegate a prompt-embeds prompt to ``renderer._process_embeds``."""
        return self.renderer._process_embeds(parsed_content)

    def _truncate_inputs(
        self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None
    ) -> list[int]:
        """
        Pass already-tokenized IDs through the renderer's singleton-prompt
        pipeline so the resolved tokenization params apply to them as well.

        Whether anything is actually truncated depends on those params and
        on the renderer's implementation (NOTE(review): confirm in
        ``BaseRenderer._tokenize_singleton_prompt``).

        Arguments:

        * inputs: prompt token IDs.
        * tokenization_kwargs: optional overrides applied on top of the
          renderer's default completion tokenization params.

        Returns:

        * The ``prompt_token_ids`` returned by the renderer.
        """
        renderer = self.renderer

        tok_params = renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )

        tok_prompt = renderer._tokenize_singleton_prompt(
            TokensPrompt(prompt_token_ids=inputs),
            tok_params,
        )

        return tok_prompt["prompt_token_ids"]

    def _process_tokens(
        self,
        parsed_content: TokensPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> TokenInputs | MultiModalInputs:
        """
        Process a prompt that already carries ``prompt_token_ids``.

        The processing steps are:

        1. The token IDs first go through ``_truncate_inputs``.
        2. If the prompt has (non-empty) multi-modal data, the IDs are handed
           to the multi-modal processor. Otherwise they are wrapped with
           ``token_inputs``.
        3. A non-empty ``prompt`` text and a non-empty ``cache_salt`` are
           copied onto the result when present.
        """
        prompt_token_ids = self._truncate_inputs(
            parsed_content["prompt_token_ids"], tokenization_kwargs
        )

        inputs: TokenInputs | MultiModalInputs
        if multi_modal_data := parsed_content.get("multi_modal_data"):
            inputs = self._process_multimodal(
                prompt_token_ids,
                multi_modal_data,
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=parsed_content.get("multi_modal_uuids"),
            )
        else:
            inputs = token_inputs(prompt_token_ids)

        # Token prompts may optionally carry their source text.
        if prompt_text := parsed_content.get("prompt"):
            inputs["prompt"] = prompt_text

        if cache_salt := parsed_content.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    def _process_text(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> TokenInputs | MultiModalInputs:
        """
        Process a text prompt.

        The processing steps are:

        1. If the prompt has (non-empty) multi-modal data, the text goes to
           the multi-modal processor. Otherwise it is tokenized with
           ``_tokenize_prompt`` and wrapped with ``token_inputs``.
        2. The prompt text is always recorded on the result.
        3. A non-empty ``cache_salt`` is copied onto the result.
        """
        prompt_text = parsed_content["prompt"]

        inputs: TokenInputs | MultiModalInputs
        if multi_modal_data := parsed_content.get("multi_modal_data"):
            inputs = self._process_multimodal(
                prompt_text,
                multi_modal_data,
                parsed_content.get("mm_processor_kwargs") or {},
                tokenization_kwargs=tokenization_kwargs,
                # Forward caller-supplied UUIDs, as _process_tokens does;
                # previously they were silently dropped for text prompts.
                mm_uuids=parsed_content.get("multi_modal_uuids"),
            )
        else:
            prompt_token_ids = self._tokenize_prompt(
                prompt_text,
                tokenization_kwargs=tokenization_kwargs,
            )
            inputs = token_inputs(prompt_token_ids)

        inputs["prompt"] = prompt_text

        if cache_salt := parsed_content.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    @overload
    def _prompt_to_llm_inputs(
        self,
        prompt: EncoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> EncoderInputs: ...

    @overload
    def _prompt_to_llm_inputs(  # type: ignore[misc]
        self,
        prompt: DecoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> DecoderInputs: ...

    @overload
    def _prompt_to_llm_inputs(  # type: ignore[misc]
        self,
        prompt: DecoderOnlyDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> DecoderOnlyInputs: ...

    def _prompt_to_llm_inputs(
        self,
        prompt: SingletonDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> SingletonInputs:
        """
        Extract the singleton inputs from a prompt.

        The prompt is dispatched on the first of these keys that it
        contains, checked in order: ``prompt_embeds``, ``prompt_token_ids``,
        ``prompt``.

        Arguments:

        * prompt: single encoder or decoder input prompt
        * tokenization_kwargs: forwarded to token and text processing
          (prompt-embeds processing does not take it)

        Returns:

        * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
        """
        if "prompt_embeds" in prompt:
            return self._process_embeds(prompt)  # type: ignore[arg-type]

        if "prompt_token_ids" in prompt:
            # Forward tokenization_kwargs so token prompts get the same
            # tokenization params as text prompts; previously they were
            # dropped here and _process_tokens fell back to the defaults.
            return self._process_tokens(
                prompt,  # type: ignore[arg-type]
                tokenization_kwargs=tokenization_kwargs,
            )

        if "prompt" in prompt:
            return self._process_text(
                prompt,  # type: ignore[arg-type]
                tokenization_kwargs=tokenization_kwargs,
            )

        assert_never(prompt)  # type: ignore[arg-type]

    def _process_encoder_decoder_prompt(
        self,
        prompt: EncoderDecoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> EncoderDecoderInputs:
        """
        For encoder/decoder models only:
        Process an input prompt into an
        [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
        instance.

        The encoder prompt is always processed. The decoder prompt is
        processed only when it is not ``None``. Both use the same
        ``tokenization_kwargs``. The renderer supplies the decoder start
        token ID.

        Arguments:

        * prompt: an input prompt
        * tokenization_kwargs: optional tokenization overrides

        Returns:

        * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
          instance
        """
        encoder_prompt = prompt["encoder_prompt"]
        decoder_prompt = prompt["decoder_prompt"]

        return build_enc_dec_inputs(
            encoder_inputs=self._prompt_to_llm_inputs(
                encoder_prompt,
                tokenization_kwargs=tokenization_kwargs,
            ),
            decoder_inputs=(
                None
                if decoder_prompt is None
                else self._prompt_to_llm_inputs(
                    decoder_prompt,
                    tokenization_kwargs=tokenization_kwargs,
                )
            ),
            decoder_start_token_id=self.renderer.get_dec_start_token_id(),
        )

    def _process_decoder_only_prompt(
        self,
        prompt: DecoderOnlyDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> DecoderOnlyInputs:
        """
        For decoder-only models:
        Process an input prompt into a
        [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.

        Arguments:

        * prompt: input prompt
        * tokenization_kwargs: optional tokenization overrides

        Returns:

        * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
        """
        return self._prompt_to_llm_inputs(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
        )

    def preprocess(
        self,
        prompt: PromptType,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> ProcessorInputs:
        """
        Preprocess the input prompt.

        Encoder-decoder models parse the prompt with
        ``parse_enc_dec_prompt``. All other models use
        ``parse_dec_only_prompt``.
        """
        if self.model_config.is_encoder_decoder:
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder.
            return self._process_encoder_decoder_prompt(
                parse_enc_dec_prompt(prompt),
                tokenization_kwargs=tokenization_kwargs,
            )

        return self._process_decoder_only_prompt(
            parse_dec_only_prompt(prompt),
            tokenization_kwargs=tokenization_kwargs,
        )