preprocess.py 15.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Mapping
5
from typing import Any, overload
6
7
8

from typing_extensions import assert_never

9
from vllm.config import VllmConfig
10
from vllm.logger import init_logger
11
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
12
13
14
15
16
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalInputs,
    MultiModalUUIDDict,
)
17
from vllm.renderers import BaseRenderer, renderer_from_config
18
19
20
21
22
23
24
25
26
27
from vllm.renderers.inputs import (
    DecoderDictPrompt,
    DecoderOnlyDictPrompt,
    DictPrompt,
    EncoderDecoderDictPrompt,
    EncoderDictPrompt,
    SingletonDictPrompt,
    TokPrompt,
)
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
28
from vllm.tokenizers import TokenizerLike
29

30
from .data import (
31
    DecoderInputs,
32
33
34
35
    DecoderOnlyInputs,
    EmbedsInputs,
    EmbedsPrompt,
    EncoderDecoderInputs,
36
    EncoderInputs,
37
38
39
40
41
42
43
44
45
    ProcessorInputs,
    PromptType,
    SingletonInputs,
    TextPrompt,
    TokenInputs,
    TokensPrompt,
    embeds_inputs,
    token_inputs,
)
46
47
48
49
50
51
52

# Module-level logger; used below for one-time fallback warnings.
logger = init_logger(__name__)


class InputPreprocessor:
    """
    Convert user prompts into the processor inputs consumed by the engine.

    - Text prompts are tokenized through the renderer.
    - Token prompts are passed back through the renderer with the
      request's tokenization params.
    - Embedding prompts are validated and moved to CPU.
    - Prompts carrying multi-modal data go to the model's multi-modal
      processor.

    For encoder-decoder models, the encoder and decoder inputs are split
    and the decoder start token is prepended when missing.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        renderer: BaseRenderer | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        super().__init__()

        self.model_config = vllm_config.model_config
        # Build a renderer from the config unless the caller supplied one.
        self.renderer = renderer or renderer_from_config(vllm_config)
        self.mm_registry = mm_registry

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The renderer's tokenizer (may be ``None``)."""
        return self.renderer.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer via ``renderer.get_tokenizer()``."""
        return self.renderer.get_tokenizer()

    def get_decoder_start_token_id(self) -> int:
        """
        Obtain the decoder start token id employed by an encoder/decoder
        model.

        Uses ``hf_config.decoder_start_token_id`` when it is set. Otherwise
        it logs a one-time warning and falls back on the renderer's <BOS>
        token id.

        Raises:

        * RuntimeError: if neither id is available.
        """
        dec_start_token_id = getattr(
            self.model_config.hf_config, "decoder_start_token_id", None
        )

        if dec_start_token_id is None:
            logger.warning_once(
                "Falling back on <BOS> for decoder start token id "
                "because decoder start token id is not available."
            )
            dec_start_token_id = self.renderer.get_bos_token_id()

        if dec_start_token_id is None:
            raise RuntimeError("Cannot find decoder start token id or <BOS>")

        return dec_start_token_id

    def _prepare_decoder_input_ids(self, decoder_input_ids: list[int]) -> list[int]:
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder models.

        Based on:
        https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
        specifically,
        `GenerationMixin._prepare_decoder_input_ids_for_generation()`.

        Arguments:

        * decoder_input_ids: input token ids to preprocess

        Returns:

        * Processed token list. The decoder start token is prepended when
          the list is empty or does not already begin with it. In that case
          a new list is returned and the argument is left unmodified.
        """
        decoder_start_token_id = self.get_decoder_start_token_id()

        if (
            len(decoder_input_ids) == 0
            or decoder_input_ids[0] != decoder_start_token_id
        ):
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

        return decoder_input_ids

    def _get_tok_params(self, tokenization_kwargs: dict[str, Any] | None = None):
        """
        Build the tokenization params for a request.

        Applies ``tokenization_kwargs`` on top of the renderer's default
        completion tokenization params via ``with_kwargs``. ``None`` means
        "use the defaults unchanged".
        """
        return self.renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )

    def _tokenize_prompt(
        self,
        prompt: str,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[int]:
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs.
        """
        tok_prompt = self.renderer.tokenize_prompt(
            TextPrompt(prompt=prompt),
            self._get_tok_params(tokenization_kwargs),
        )

        return tok_prompt["prompt_token_ids"]

    def _process_multimodal(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object] | None,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """
        Apply the model's multi-modal processor to a multi-modal prompt,
        returning the corresponding token IDs and metadata.

        Arguments:

        * prompt: either the raw prompt text or already-tokenized ids
        * mm_data: the multi-modal data attached to the prompt
        * mm_processor_kwargs: extra kwargs for the HF processor
          (``None`` is treated as empty)
        * tokenization_kwargs: forwarded to the multi-modal processor
        * mm_uuids: optional caller-supplied identifiers for the items
        """
        mm_processor = self.renderer.get_mm_processor()

        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}

        mm_items = mm_processor.info.parse_mm_data(mm_data)

        return mm_processor.apply(
            prompt,
            mm_items,
            hf_processor_mm_kwargs=mm_processor_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        """
        Validate a prompt-embeddings prompt and wrap it as ``EmbedsInputs``.

        Raises:

        * ValueError: if prompt embeddings are not enabled in the model
          config.
        * ValueError: if the tensor is not 2-D after squeezing a leading
          dimension of a 3-D tensor.
        """
        if not self.model_config.enable_prompt_embeds:
            raise ValueError(
                "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
            )

        prompt_embeds = parsed_content["prompt_embeds"]

        # prompt_embeds must be (seq_len, hidden_size), but if the user
        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # we can unambiguously process the intent by squeezing the batch
        # dimension.
        if prompt_embeds.ndim == 3:
            prompt_embeds = prompt_embeds.squeeze(dim=0)

        if prompt_embeds.ndim != 2:
            raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")

        # Tensors must be on CPU for serialization between processes
        # in the MsgpackEncoder. Casting to CPU here ensures that there is no
        # hidden device transfer in the critical path of generation.
        prompt_embeds = prompt_embeds.cpu()

        return embeds_inputs(
            prompt_embeds=prompt_embeds, cache_salt=parsed_content.get("cache_salt")
        )

    def _truncate_inputs(
        self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None
    ) -> list[int]:
        """
        Pass pre-tokenized ids back through the renderer with the request's
        tokenization params, and return the resulting ids.

        Judging by the name, this presumably applies truncation/length
        limits from the params. The exact behavior is determined by
        ``renderer.tokenize_prompt`` (TODO confirm).
        """
        tok_prompt = self.renderer.tokenize_prompt(
            TokensPrompt(prompt_token_ids=inputs),
            self._get_tok_params(tokenization_kwargs),
        )

        return tok_prompt["prompt_token_ids"]

    def _process_tokens(
        self,
        parsed_content: TokensPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> TokenInputs | MultiModalInputs:
        """
        Process a pre-tokenized prompt.

        The ids first go through ``_truncate_inputs``. If the prompt
        carries multi-modal data, it is then handed to the multi-modal
        processor; otherwise it is wrapped as token inputs. A truthy
        ``cache_salt`` on the prompt is copied onto the result.
        """
        prompt_token_ids = self._truncate_inputs(
            parsed_content["prompt_token_ids"], tokenization_kwargs
        )

        inputs: TokenInputs | MultiModalInputs
        if multi_modal_data := parsed_content.get("multi_modal_data"):
            inputs = self._process_multimodal(
                prompt_token_ids,
                multi_modal_data,
                parsed_content.get("mm_processor_kwargs") or {},
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
        else:
            inputs = token_inputs(prompt_token_ids)

        if cache_salt := parsed_content.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    def _process_text(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> TokenInputs | MultiModalInputs:
        """
        Process a text prompt.

        If the prompt carries multi-modal data, the raw text goes to the
        multi-modal processor. Otherwise it is tokenized with the
        renderer. A truthy ``cache_salt`` on the prompt is copied onto the
        result.
        """
        prompt_text = parsed_content["prompt"]

        inputs: TokenInputs | MultiModalInputs
        if multi_modal_data := parsed_content.get("multi_modal_data"):
            inputs = self._process_multimodal(
                prompt_text,
                multi_modal_data,
                parsed_content.get("mm_processor_kwargs") or {},
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
        else:
            prompt_token_ids = self._tokenize_prompt(
                prompt_text,
                tokenization_kwargs=tokenization_kwargs,
            )
            inputs = token_inputs(prompt_token_ids)

        if cache_salt := parsed_content.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    @overload
    def _prompt_to_llm_inputs(
        self,
        prompt: EncoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> EncoderInputs: ...

    @overload
    def _prompt_to_llm_inputs(  # type: ignore[misc]
        self,
        prompt: DecoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> DecoderInputs: ...

    @overload
    def _prompt_to_llm_inputs(  # type: ignore[misc]
        self,
        prompt: DecoderOnlyDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> DecoderOnlyInputs: ...

    def _prompt_to_llm_inputs(
        self,
        prompt: SingletonDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> SingletonInputs:
        """
        Extract the singleton inputs from a prompt.

        The prompt kind is chosen by key, checked in this order:
        ``prompt_embeds``, then ``prompt_token_ids``, then ``prompt``.

        Arguments:

        * prompt: single encoder or decoder input prompt
        * tokenization_kwargs: per-request tokenization overrides, applied
          to both text and token prompts

        Returns:

        * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
        """
        if "prompt_embeds" in prompt:
            return self._process_embeds(prompt)  # type: ignore[arg-type]

        if "prompt_token_ids" in prompt:
            # Forward tokenization_kwargs so token prompts honor the same
            # per-request settings (e.g. truncation) as text prompts.
            return self._process_tokens(
                prompt,  # type: ignore[arg-type]
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        if "prompt" in prompt:
            return self._process_text(
                prompt,  # type: ignore[arg-type]
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        assert_never(prompt)  # type: ignore[arg-type]

    def _validate_enc_inputs(self, inputs: SingletonInputs) -> EncoderInputs:
        """
        Check that ``inputs`` can be used as encoder inputs.

        Raises:

        * ValueError: for embedding inputs.
        * RuntimeError: for multi-modal inputs produced by a processor that
          did not emit ``encoder_prompt_token_ids``.
        """
        if inputs["type"] == "embeds":
            raise ValueError(
                "Embedding inputs are not supported for encoder-decoder models"
            )

        if inputs["type"] == "multimodal" and "encoder_prompt_token_ids" not in inputs:
            raise RuntimeError(
                "You should register an encoder-decoder "
                "multi-modal processor for encoder-decoder models."
            )

        return inputs  # type: ignore[return-value]

    def _validate_dec_inputs(self, inputs: SingletonInputs) -> DecoderInputs:
        """
        Check that ``inputs`` can be used as decoder inputs.

        Raises:

        * ValueError: for embedding inputs.
        """
        if inputs["type"] == "embeds":
            raise ValueError(
                "Embedding inputs are not supported for encoder-decoder models"
            )

        return inputs

    def _build_enc_dec_inputs(
        self,
        encoder_inputs: SingletonInputs,
        decoder_inputs: SingletonInputs | None = None,
    ) -> EncoderDecoderInputs:
        """
        Combine validated encoder and decoder inputs into
        ``EncoderDecoderInputs``.

        When ``decoder_inputs`` is ``None``, the encoder inputs also serve
        as the decoder inputs. The encoder's ``cache_salt``, if truthy, is
        copied onto the decoder side.
        """
        enc_inputs = self._validate_enc_inputs(encoder_inputs)

        if decoder_inputs is None:
            dec_inputs: DecoderInputs = enc_inputs  # type: ignore[assignment]
        else:
            dec_inputs = self._validate_dec_inputs(decoder_inputs)

        enc_inputs_new: EncoderInputs
        dec_inputs_new: DecoderInputs

        if enc_inputs["type"] == "multimodal":
            # The encoder receives the processor's encoder-side token ids.
            # The decoder keeps its own prompt ids plus the multi-modal
            # metadata.
            enc_inputs_new = token_inputs(enc_inputs["encoder_prompt_token_ids"])
            dec_inputs_new = MultiModalInputs(
                type="multimodal",
                prompt_token_ids=dec_inputs["prompt_token_ids"],
                mm_kwargs=enc_inputs["mm_kwargs"],
                mm_hashes=enc_inputs["mm_hashes"],
                mm_placeholders=enc_inputs["mm_placeholders"],
            )
        elif enc_inputs["type"] == "token":
            # Token-only encoder input: the encoder gets an empty token list
            # and the decoder inputs are used as-is.
            enc_inputs_new = token_inputs(prompt_token_ids=[])
            dec_inputs_new = dec_inputs
        else:
            assert_never(enc_inputs)

        # NOTE: in the token branch without explicit decoder inputs,
        # dec_inputs_new is the same dict object as enc_inputs, so the
        # writes below mutate it in place. enc_inputs_new is a fresh dict,
        # so the returned encoder side is unaffected.
        dec_inputs_new["prompt_token_ids"] = self._prepare_decoder_input_ids(
            dec_inputs_new["prompt_token_ids"]
        )
        if cache_salt := enc_inputs.get("cache_salt"):
            dec_inputs_new["cache_salt"] = cache_salt

        return EncoderDecoderInputs(encoder=enc_inputs_new, decoder=dec_inputs_new)

    def _process_encoder_decoder_prompt(
        self,
        prompt: EncoderDecoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> EncoderDecoderInputs:
        """
        For encoder/decoder models only:
        Process an input prompt into an
        [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
        instance.

        ``mm_uuids`` are forwarded only to the encoder prompt.

        Arguments:

        * prompt: an input prompt

        Returns:

        * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
          instance
        """
        encoder_prompt = prompt["encoder_prompt"]
        decoder_prompt = prompt["decoder_prompt"]

        return self._build_enc_dec_inputs(
            encoder_inputs=self._prompt_to_llm_inputs(
                encoder_prompt,
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            ),
            decoder_inputs=(
                None
                if decoder_prompt is None
                else self._prompt_to_llm_inputs(
                    decoder_prompt,
                    tokenization_kwargs=tokenization_kwargs,
                )
            ),
        )

    def _process_decoder_only_prompt(
        self,
        prompt: DecoderOnlyDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> DecoderOnlyInputs:
        """
        For decoder-only models:
        Process an input prompt into a
        [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.

        Arguments:

        * prompt: input prompt

        Returns:

        * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
        """
        return self._prompt_to_llm_inputs(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    def _preprocess(
        self,
        prompt: PromptType | DictPrompt | TokPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> ProcessorInputs:
        """Dispatch on the model's architecture (encoder-decoder or
        decoder-only)."""
        if self.model_config.is_encoder_decoder:
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder.
            return self._process_encoder_decoder_prompt(
                parse_enc_dec_prompt(prompt),
                tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        return self._process_decoder_only_prompt(
            parse_dec_only_prompt(prompt),
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    def preprocess(
        self,
        prompt: PromptType | DictPrompt | TokPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> ProcessorInputs:
        """
        Preprocess the input prompt.

        After processing, the renderer's multi-modal cache stats are
        updated via ``renderer.update_mm_cache_stats()``.
        """
        res = self._preprocess(prompt, tokenization_kwargs, mm_uuids=mm_uuids)

        self.renderer.update_mm_cache_stats()

        return res