preprocess.py 11.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Mapping
5
from typing import Any, overload
6
7
8

from typing_extensions import assert_never

9
from vllm.config import VllmConfig
10
from vllm.inputs.data import build_enc_dec_inputs
11
from vllm.logger import init_logger
12
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
13
14
15
16
17
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalInputs,
    MultiModalUUIDDict,
)
18
from vllm.renderers import BaseRenderer, renderer_from_config
19
20
21
22
23
24
25
26
27
28
from vllm.renderers.inputs import (
    DecoderDictPrompt,
    DecoderOnlyDictPrompt,
    DictPrompt,
    EncoderDecoderDictPrompt,
    EncoderDictPrompt,
    SingletonDictPrompt,
    TokPrompt,
)
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
29
from vllm.tokenizers import TokenizerLike
30

31
from .data import (
32
    DecoderInputs,
33
34
35
36
    DecoderOnlyInputs,
    EmbedsInputs,
    EmbedsPrompt,
    EncoderDecoderInputs,
37
    EncoderInputs,
38
39
40
41
42
43
44
45
46
    ProcessorInputs,
    PromptType,
    SingletonInputs,
    TextPrompt,
    TokenInputs,
    TokensPrompt,
    embeds_inputs,
    token_inputs,
)
47
48
49
50
51
52
53

# Module-level logger, obtained from vLLM's `init_logger` using this
# module's name.
logger = init_logger(__name__)


class InputPreprocessor:
    def __init__(
        self,
        vllm_config: VllmConfig,
        renderer: BaseRenderer | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        """
        Set up the preprocessor from the engine configuration.

        Arguments:

        * vllm_config: configuration whose `model_config` is kept on the
          instance
        * renderer: renderer to use; when it is falsy (e.g. `None`), one is
          built from `vllm_config` via `renderer_from_config`
        * mm_registry: multi-modal registry stored on the instance
        """
        super().__init__()

        self.model_config = vllm_config.model_config

        # Only build a renderer from the config when the caller gave none.
        active_renderer = renderer if renderer else renderer_from_config(vllm_config)
        self.renderer = active_renderer

        self.mm_registry = mm_registry
63

64
65
66
    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The renderer's `tokenizer` attribute (annotated as possibly `None`)."""
        renderer = self.renderer
        return renderer.tokenizer
67

68
69
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer by delegating to the renderer's `get_tokenizer()`."""
        tokenizer = self.renderer.get_tokenizer()
        return tokenizer
70
71
72
73

    def _tokenize_prompt(
        self,
        prompt: str,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[int]:
        """
        Tokenize a text prompt through the renderer and return its token IDs.

        Arguments:

        * prompt: the text to tokenize
        * tokenization_kwargs: optional overrides applied on top of the
          renderer's `default_cmpl_tok_params`

        Returns:

        * the `"prompt_token_ids"` entry of the tokenized prompt
        """
        renderer = self.renderer

        # A missing or empty kwargs dict means "use the defaults unchanged".
        overrides = tokenization_kwargs if tokenization_kwargs else {}
        params = renderer.default_cmpl_tok_params.with_kwargs(**overrides)

        text_prompt = TextPrompt(prompt=prompt)
        tokenized = renderer.tokenize_prompt(text_prompt, params)

        return tokenized["prompt_token_ids"]
92

93
94
    def _process_multimodal(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object] | None,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """
        Run the renderer's multi-modal processor over a multi-modal prompt.

        Arguments:

        * prompt: the prompt as text or as token IDs
        * mm_data: raw multi-modal data, parsed with the processor's
          `info.parse_mm_data` before being applied
        * mm_processor_kwargs: forwarded as `hf_processor_mm_kwargs`;
          `None` is replaced by an empty dict
        * tokenization_kwargs: forwarded unchanged to the processor
        * mm_uuids: forwarded unchanged to the processor

        Returns:

        * the result of the processor's `apply` call
        """
        processor = self.renderer.get_mm_processor()

        hf_kwargs = {} if mm_processor_kwargs is None else mm_processor_kwargs
        parsed_items = processor.info.parse_mm_data(mm_data)

        processed = processor.apply(
            prompt,
            parsed_items,
            hf_processor_mm_kwargs=hf_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
        return processed
120

121
122
123
124
    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        """
        Validate a prompt-embeddings prompt and convert it to `EmbedsInputs`.

        Arguments:

        * parsed_content: prompt carrying `"prompt_embeds"` and, optionally,
          `"cache_salt"`

        Returns:

        * the `embeds_inputs(...)` result built from the CPU copy of the
          embeddings and the prompt's cache salt (if any)

        Raises:

        * ValueError: if prompt embeds are not enabled in the model config,
          or if the embeddings are not 2-D after the optional squeeze
        """
        if not self.model_config.enable_prompt_embeds:
            raise ValueError(
                "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
            )

        embeds = parsed_content["prompt_embeds"]

        # The expected layout is (seq_len, hidden_size). A user may pass a
        # batch of size 1, i.e. (1, seq_len, hidden_size); dropping that
        # leading batch dimension is unambiguous.
        # NOTE(review): presumably a torch tensor, where `squeeze(dim=0)`
        # leaves a leading dimension of any other size in place, so such
        # inputs are rejected by the rank check below -- confirm.
        if embeds.ndim == 3:
            embeds = embeds.squeeze(dim=0)

        if embeds.ndim != 2:
            raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")

        # Tensors have to live on the CPU to be serialized between processes
        # by the MsgpackEncoder. Moving them here keeps any device transfer
        # out of the critical path of generation.
        return embeds_inputs(
            prompt_embeds=embeds.cpu(),
            cache_salt=parsed_content.get("cache_salt"),
        )
150

151
    def _truncate_inputs(
        self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None
    ) -> list[int]:
        """
        Pass already-tokenized input through the renderer's tokenization step.

        The token IDs are wrapped in a `TokensPrompt` and handed to
        `renderer.tokenize_prompt` together with the renderer's
        `default_cmpl_tok_params` overridden by `tokenization_kwargs`.
        NOTE(review): the actual truncation presumably happens inside the
        renderer according to those parameters -- confirm.

        Arguments:

        * inputs: prompt token IDs
        * tokenization_kwargs: optional overrides for the default parameters

        Returns:

        * the `"prompt_token_ids"` entry of the renderer's result
        """
        renderer = self.renderer

        # A missing or empty kwargs dict means "use the defaults unchanged".
        overrides = tokenization_kwargs if tokenization_kwargs else {}
        params = renderer.default_cmpl_tok_params.with_kwargs(**overrides)

        tokens_prompt = TokensPrompt(prompt_token_ids=inputs)
        processed = renderer.tokenize_prompt(tokens_prompt, params)

        return processed["prompt_token_ids"]
166

167
168
169
    def _process_tokens(
        self,
        parsed_content: TokensPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> TokenInputs | MultiModalInputs:
        """
        Convert a token-ID prompt into model inputs.

        The token IDs first go through `_truncate_inputs`. If the prompt
        carries non-empty `"multi_modal_data"`, the multi-modal processor is
        applied; otherwise plain token inputs are built.

        Arguments:

        * parsed_content: prompt with `"prompt_token_ids"` and optional
          `"multi_modal_data"`, `"mm_processor_kwargs"` and `"cache_salt"`
        * tokenization_kwargs: forwarded to truncation and to the
          multi-modal processor
        * mm_uuids: forwarded to the multi-modal processor

        Returns:

        * the resulting inputs, with `"cache_salt"` set when the prompt
          provides a truthy one
        """
        token_ids = self._truncate_inputs(
            parsed_content["prompt_token_ids"], tokenization_kwargs
        )

        mm_data = parsed_content.get("multi_modal_data")

        inputs: TokenInputs | MultiModalInputs
        if not mm_data:
            # No (or empty) multi-modal data: plain token inputs suffice.
            inputs = token_inputs(token_ids)
        else:
            mm_kwargs = parsed_content.get("mm_processor_kwargs") or {}
            inputs = self._process_multimodal(
                token_ids,
                mm_data,
                mm_kwargs,
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        # Only a truthy salt is copied over; a missing/empty one is skipped.
        salt = parsed_content.get("cache_salt")
        if salt:
            inputs["cache_salt"] = salt

        return inputs

    def _process_text(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> TokenInputs | MultiModalInputs:
        """
        Convert a text prompt into model inputs.

        If the prompt carries non-empty `"multi_modal_data"`, the text is
        handed to the multi-modal processor; otherwise it is tokenized with
        `_tokenize_prompt` and wrapped as plain token inputs.

        Arguments:

        * parsed_content: prompt with `"prompt"` and optional
          `"multi_modal_data"`, `"mm_processor_kwargs"` and `"cache_salt"`
        * tokenization_kwargs: forwarded to tokenization or to the
          multi-modal processor
        * mm_uuids: forwarded to the multi-modal processor

        Returns:

        * the resulting inputs, with `"cache_salt"` set when the prompt
          provides a truthy one
        """
        text = parsed_content["prompt"]
        mm_data = parsed_content.get("multi_modal_data")

        inputs: TokenInputs | MultiModalInputs
        if not mm_data:
            # Text-only prompt: tokenize it directly.
            token_ids = self._tokenize_prompt(
                text,
                tokenization_kwargs=tokenization_kwargs,
            )
            inputs = token_inputs(token_ids)
        else:
            mm_kwargs = parsed_content.get("mm_processor_kwargs") or {}
            inputs = self._process_multimodal(
                text,
                mm_data,
                mm_kwargs,
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        # Only a truthy salt is copied over; a missing/empty one is skipped.
        salt = parsed_content.get("cache_salt")
        if salt:
            inputs["cache_salt"] = salt

        return inputs
224

225
    @overload
    def _prompt_to_llm_inputs(
        self,
        prompt: EncoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> EncoderInputs: ...

    @overload
    def _prompt_to_llm_inputs(  # type: ignore[misc]
        self,
        prompt: DecoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> DecoderInputs: ...

    @overload
    def _prompt_to_llm_inputs(  # type: ignore[misc]
        self,
        prompt: DecoderOnlyDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> DecoderOnlyInputs: ...

    def _prompt_to_llm_inputs(
        self,
        prompt: SingletonDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> SingletonInputs:
        """
        Turn a single (singleton) prompt into its model inputs.

        The handler is picked by which key the prompt contains, checked in
        this order: `"prompt_embeds"`, then `"prompt_token_ids"`, then
        `"prompt"`.

        Arguments:

        * prompt: single encoder or decoder input prompt
        * tokenization_kwargs: forwarded to the text handler
        * mm_uuids: forwarded to the token and text handlers

        Returns:

        * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
        """
        llm_inputs: SingletonInputs

        if "prompt_embeds" in prompt:
            llm_inputs = self._process_embeds(prompt)  # type: ignore[arg-type]
        elif "prompt_token_ids" in prompt:
            # NOTE(review): `tokenization_kwargs` is not forwarded on this
            # path even though `_process_tokens` accepts it -- confirm that
            # this is intentional.
            llm_inputs = self._process_tokens(
                prompt,  # type: ignore[arg-type]
                mm_uuids=mm_uuids,
            )
        elif "prompt" in prompt:
            llm_inputs = self._process_text(
                prompt,  # type: ignore[arg-type]
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
        else:
            # None of the known keys is present.
            assert_never(prompt)  # type: ignore[arg-type]

        return llm_inputs
287

288
289
    def _process_encoder_decoder_prompt(
        self,
        prompt: EncoderDecoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> EncoderDecoderInputs:
        """
        For encoder/decoder models only:
        Build an
        [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
        instance from a prompt that has separate encoder and decoder parts.

        The encoder part is always processed; the decoder part is processed
        only when it is not `None`. Both are combined with the renderer's
        decoder start token ID by `build_enc_dec_inputs`.

        Arguments:

        * prompt: an input prompt with `"encoder_prompt"` and
          `"decoder_prompt"` entries
        * tokenization_kwargs: forwarded to both the encoder and the decoder
          prompt processing
        * mm_uuids: forwarded to the encoder prompt processing only

        Returns:

        * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
          instance
        """
        enc_prompt, dec_prompt = prompt["encoder_prompt"], prompt["decoder_prompt"]

        enc_inputs = self._prompt_to_llm_inputs(
            enc_prompt,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        if dec_prompt is None:
            dec_inputs = None
        else:
            # NOTE(review): `mm_uuids` is not passed for the decoder prompt;
            # presumably multi-modal data only belongs to the encoder side --
            # confirm.
            dec_inputs = self._prompt_to_llm_inputs(
                dec_prompt,
                tokenization_kwargs=tokenization_kwargs,
            )

        start_token_id = self.renderer.get_dec_start_token_id()

        return build_enc_dec_inputs(
            encoder_inputs=enc_inputs,
            decoder_inputs=dec_inputs,
            decoder_start_token_id=start_token_id,
        )
329
330
331

    def _process_decoder_only_prompt(
        self,
        prompt: DecoderOnlyDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> DecoderOnlyInputs:
        """
        For decoder-only models:
        Build a [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs]
        instance by delegating to `_prompt_to_llm_inputs`.

        Arguments:

        * prompt: input prompt
        * tokenization_kwargs: forwarded unchanged
        * mm_uuids: forwarded unchanged

        Returns:

        * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
        """
        decoder_only_inputs = self._prompt_to_llm_inputs(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
        return decoder_only_inputs

356
    def _preprocess(
        self,
        prompt: PromptType | DictPrompt | TokPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> ProcessorInputs:
        """
        Parse the prompt and route it to the handler for the model type.

        Arguments:

        * prompt: the raw input prompt
        * tokenization_kwargs: forwarded to the selected handler
        * mm_uuids: forwarded to the selected handler

        Returns:

        * the inputs produced by the encoder-decoder handler when
          `model_config.is_encoder_decoder` is set, otherwise those of the
          decoder-only handler
        """
        if not self.model_config.is_encoder_decoder:
            return self._process_decoder_only_prompt(
                parse_dec_only_prompt(prompt),
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        # An encoder-decoder model needs the prompt mapped onto separate
        # encoder and decoder parts.
        enc_dec_prompt = parse_enc_dec_prompt(prompt)
        return self._process_encoder_decoder_prompt(
            enc_dec_prompt,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

378
379
    def preprocess(
        self,
        prompt: PromptType | DictPrompt | TokPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> ProcessorInputs:
        """
        Preprocess the input prompt.

        Delegates to `_preprocess` and, once that has returned, calls the
        renderer's `update_mm_cache_stats()` before handing back the result.

        Arguments:

        * prompt: the raw input prompt
        * tokenization_kwargs: forwarded to `_preprocess`
        * mm_uuids: forwarded to `_preprocess`

        Returns:

        * the processed inputs
        """
        processed = self._preprocess(prompt, tokenization_kwargs, mm_uuids=mm_uuids)

        # The stats update only runs when preprocessing did not raise.
        self.renderer.update_mm_cache_stats()

        return processed