renderer.py 14.9 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
5
import io
6
from abc import ABC, abstractmethod
7
from dataclasses import dataclass
8
from typing import Annotated
9

10
11
import pybase64
import torch
12
13
14
from pydantic import Field

from vllm.config import ModelConfig
15
from vllm.entrypoints.openai.protocol import VLLMValidationError
16
from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt
17
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
18
from vllm.tokenizers import TokenizerLike
19
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
20
21


22
23
24
25
@dataclass(frozen=True)
class RenderConfig:
    """Immutable options that govern how prompts are turned into engine inputs."""

    max_length: int | None = None
    """Upper bound on the total number of input tokens. When set, token
    inputs exceeding it raise `ValueError`."""

    truncate_prompt_tokens: int | None = None
    """How many tokens to keep. `None` disables truncation, `0` produces an
    empty list (embeddings are skipped too), and `-1` stands for
    `model_config.max_model_len`."""

    add_special_tokens: bool = True
    """Whether tokenization should insert the model's special tokens."""

    cache_salt: str | None = None
    """Optional string used to keep prefix-cache entries apart."""

    needs_detokenization: bool | None = False
    """When True, token IDs are decoded back to text for the outputs."""

    def verify_truncate_prompt_tokens(self, model_config: ModelConfig) -> int | None:
        """Resolve `truncate_prompt_tokens` into a concrete token count.

        Args:
            model_config: Supplies `max_model_len`, which replaces any
                negative truncation request.

        Returns:
            `None` when truncation is disabled, `0` when zero tokens were
            requested, otherwise the resolved token count.

        Raises:
            ValueError: If the resolved count exceeds `max_length`.
        """
        truncate_prompt_tokens = self.truncate_prompt_tokens
        if not truncate_prompt_tokens:
            # Covers both "disabled" (None) and "keep nothing" (0).
            return None if truncate_prompt_tokens is None else 0

        # Negative requests mean "as many tokens as the model allows".
        truncate_prompt_tokens = (
            model_config.max_model_len
            if truncate_prompt_tokens < 0
            else truncate_prompt_tokens
        )

        max_length = self.max_length
        if max_length is None or truncate_prompt_tokens <= max_length:  # type: ignore[operator]
            return truncate_prompt_tokens

        # NOTE: the `=` specifiers below print the local variable names, so
        # `truncate_prompt_tokens` / `max_length` must keep these names.
        raise ValueError(
            f"{truncate_prompt_tokens=} cannot be greater than "
            f"{max_length=}. Please select a smaller truncation size."
        )

65

66
67
68
class BaseRenderer(ABC):
    """
    Base class for unified input processing and rendering.

    The Renderer serves as a unified input processor that consolidates
    tokenization, chat template formatting, and multimodal input handling
    into a single component.
    It converts high-level API requests (OpenAI-style JSON) into token IDs and
    multimodal features ready for engine consumption.

    Key responsibilities:
    - Convert text prompts to token sequences with proper special tokens
    - Apply chat templates and format conversations
    - Handle multimodal inputs (images, audio, etc.) when applicable
    - Manage prompt truncation and length validation
    - Provide clean separation between API layer and engine core
    """

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: TokenizerLike | None = None,
    ):
        super().__init__()
        self.model_config = model_config
        # May be None; subclasses must handle text input without a tokenizer.
        self.tokenizer = tokenizer

    @abstractmethod
    async def render_prompt(
        self,
        *,
        prompt_or_prompts: str | list[str] | list[int] | list[list[int]],
        config: RenderConfig,
    ) -> list[TokensPrompt]:
        """
        Convert text or token inputs into engine-ready TokensPrompt objects.

        This method accepts text or token inputs and produces a
        list of [`TokensPrompt`][vllm.inputs.data.TokensPrompt] objects
        for the engine.

        Args:
            prompt_or_prompts: One of:
                - `str`: Single text prompt.
                - `list[str]`: Batch of text prompts.
                - `list[int]`: Single pre-tokenized sequence.
                - `list[list[int]]`: Batch of pre-tokenized sequences.
            config: Render configuration controlling how prompts are prepared
                (e.g., tokenization and length handling).

        Returns:
            list[TokensPrompt]: Engine-ready token prompts.

        Raises:
            ValueError: If input formats are invalid or length limits exceeded.
        """
        raise NotImplementedError

    @abstractmethod
    async def render_prompt_and_embeds(
        self,
        *,
        prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None,
        prompt_embeds: bytes | list[bytes] | None = None,
        config: RenderConfig,
    ) -> list[TokensPrompt | EmbedsPrompt]:
        """
        Convert text/token and/or base64-encoded embeddings inputs into
        engine-ready prompt objects using a unified RenderConfig.

        At least one of `prompt_or_prompts` or `prompt_embeds` must be
        provided and non-empty. If both are omitted or empty (e.g., empty
        string and empty list), a `ValueError` is raised.

        Args:
            prompt_or_prompts: Text or token inputs to include.
            prompt_embeds: Base64-encoded bytes (or list thereof) containing a
                torch-saved tensor to be used as prompt embeddings.
            config: Render configuration controlling how prompts are prepared
                (e.g., tokenization and length handling).

        Returns:
            list[TokensPrompt | EmbedsPrompt]: Engine-ready prompt objects.

        Raises:
            ValueError: If both `prompt_or_prompts` and `prompt_embeds`
                are omitted or empty (decoder prompt cannot be empty), or if
                length limits are exceeded.
        """
        raise NotImplementedError

    def load_prompt_embeds(
        self,
        prompt_embeds: bytes | list[bytes],
        truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None,
        cache_salt: str | None = None,
    ) -> list[EmbedsPrompt]:
        """Load and validate base64-encoded embeddings into prompt objects.

        Args:
            prompt_embeds: One base64 payload, or a list of them. Each must
                decode to a `torch.save`-d tensor.
            truncate_prompt_tokens: If set, keep only the last this-many rows
                (along dim 0) of each tensor; `0` keeps no rows.
            cache_salt: If set, stored under `"cache_salt"` on every prompt.

        Returns:
            One `EmbedsPrompt` per payload, in input order.

        Raises:
            VLLMValidationError: If prompt embeddings are not enabled for the
                model, or a payload does not decode to a float32/bfloat16/
                float16 tensor that is 2-D (after dropping a leading
                dimension of size 1).
        """
        if not self.model_config.enable_prompt_embeds:
            raise VLLMValidationError(
                "You must set `--enable-prompt-embeds` to input `prompt_embeds`.",
                parameter="prompt_embeds",
            )

        allowed_dtypes = (torch.float32, torch.bfloat16, torch.float16)

        def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
            # Enable sparse tensor integrity checks to prevent out-of-bounds
            # writes from maliciously crafted tensors
            with torch.sparse.check_sparse_tensor_invariants():
                tensor = torch.load(
                    io.BytesIO(pybase64.b64decode(embed, validate=True)),
                    weights_only=True,
                    map_location=torch.device("cpu"),
                )
                # Explicit checks rather than `assert`: this is untrusted user
                # input, and asserts are stripped under `python -O`.
                if not isinstance(tensor, torch.Tensor):
                    raise VLLMValidationError(
                        "`prompt_embeds` must decode to a torch.Tensor, "
                        f"got {type(tensor).__name__}.",
                        parameter="prompt_embeds",
                    )
                if tensor.dtype not in allowed_dtypes:
                    raise VLLMValidationError(
                        "`prompt_embeds` must have dtype float32, bfloat16 "
                        f"or float16, got {tensor.dtype}.",
                        parameter="prompt_embeds",
                    )
                tensor = tensor.to_dense()

            if tensor.dim() > 2:
                # Drop a leading dimension of size 1 (squeeze(0) is a no-op
                # otherwise, which the check below then rejects).
                tensor = tensor.squeeze(0)
            if tensor.dim() != 2:
                raise VLLMValidationError(
                    "`prompt_embeds` must be a 2-D tensor (a leading dimension "
                    f"of size 1 is also accepted), got shape {tuple(tensor.shape)}.",
                    parameter="prompt_embeds",
                )

            if truncate_prompt_tokens is not None:
                # Keep the last N rows. An explicit start index is used
                # because `tensor[-0:]` would keep everything when N == 0.
                start = max(tensor.shape[0] - truncate_prompt_tokens, 0)
                tensor = tensor[start:]

            embeds_prompt = EmbedsPrompt(prompt_embeds=tensor)
            if cache_salt is not None:
                embeds_prompt["cache_salt"] = cache_salt
            return embeds_prompt

        if isinstance(prompt_embeds, list):
            return [_load_and_validate_embed(embed) for embed in prompt_embeds]

        return [_load_and_validate_embed(prompt_embeds)]
200

201
202
203
204
205

class CompletionRenderer(BaseRenderer):
    """Renderer for completion-style requests: text prompts, pre-tokenized
    prompts, and/or base64-encoded prompt embeddings."""

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: TokenizerLike | None = None,
        async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer]
        | None = None,
    ):
        super().__init__(model_config, tokenizer)
        # Optional mapping tokenizer -> AsyncMicrobatchTokenizer; entries are
        # created on demand by `_get_async_tokenizer`.
        self.async_tokenizer_pool = async_tokenizer_pool
        # Resolved lazily on first use and then reused.
        self.async_tokenizer: AsyncMicrobatchTokenizer | None = None

    async def render_prompt(
        self,
        *,
        prompt_or_prompts: str | list[str] | list[int] | list[list[int]],
        config: RenderConfig,
    ) -> list[TokensPrompt]:
        """Implementation of prompt rendering for completion-style requests.

        Uses async tokenizer pooling for improved performance. See base class
        for detailed parameter documentation.
        """
        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config)
        if truncate_prompt_tokens == 0:
            return []

        tasks = (
            self._create_prompt(
                prompt_input,
                config=config,
                truncate_prompt_tokens=truncate_prompt_tokens,
            )
            for prompt_input in parse_raw_prompts(prompt_or_prompts)
        )

        return await asyncio.gather(*tasks)

    async def render_prompt_and_embeds(
        self,
        *,
        prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None,
        prompt_embeds: bytes | list[bytes] | None = None,
        config: RenderConfig,
    ) -> list[TokensPrompt | EmbedsPrompt]:
        """
        Render text/token prompts and/or precomputed embedding prompts. At
        least one of `prompt_or_prompts` or `prompt_embeds` must be provided.

        Embedding prompts come first in the result, followed by token prompts.
        `None`, `""` and `[]` for `prompt_or_prompts` all mean "no
        text/token prompts".

        Raises:
            ValueError: If neither input produces any prompt. This does not
                apply when truncation to 0 tokens was requested; an empty
                list is returned in that case.
        """
        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config)
        if truncate_prompt_tokens == 0:
            # Truncating to zero tokens yields no prompts (embeds included).
            return []

        rendered: list[TokensPrompt | EmbedsPrompt] = []

        if prompt_embeds is not None:
            rendered.extend(
                self.load_prompt_embeds(
                    prompt_embeds, truncate_prompt_tokens, config.cache_salt
                )
            )

        # Empty str / empty list are treated like "not provided".
        if prompt_or_prompts:
            token_prompts = await self.render_prompt(
                prompt_or_prompts=prompt_or_prompts,
                config=config,
            )
            rendered.extend(token_prompts)

        if not rendered:
            # Enforce the documented contract: the decoder prompt cannot be
            # empty, so at least one input must contribute a prompt.
            raise ValueError(
                "Either a prompt or `prompt_embeds` must be provided and non-empty."
            )

        return rendered

    def _maybe_apply_truncation(
        self, token_ids: list[int], truncate_prompt_tokens: int | None
    ) -> list[int]:
        """Keep only the last `truncate_prompt_tokens` token IDs.

        Returns `token_ids` itself when no limit is set or the limit is at
        least the sequence length; a limit of 0 returns `[]`.
        """
        if truncate_prompt_tokens is None:
            return token_ids
        if truncate_prompt_tokens >= len(token_ids):
            return token_ids

        # Explicit start index: `token_ids[-0:]` would return the whole list.
        return token_ids[len(token_ids) - truncate_prompt_tokens :]

    async def _create_prompt(
        self,
        prompt_input: TextPrompt | TokensPrompt,
        config: RenderConfig,
        truncate_prompt_tokens: int | None,
    ) -> TokensPrompt:
        """Dispatch one parsed prompt to the token-ID or text path.

        Raises:
            NotImplementedError: If the prompt carries neither token IDs nor
                text.
        """
        prompt, prompt_token_ids, _ = get_prompt_components(prompt_input)

        if prompt_token_ids is not None:
            # NOTE: detokenization is needed when echo is enabled,
            # where the input token IDs are decoded back to text.
            return await self._create_prompt_from_token_ids(
                prompt_token_ids,
                config.max_length,
                truncate_prompt_tokens,
                config.cache_salt,
                config.needs_detokenization,
            )

        if prompt is not None:
            return await self._create_prompt_from_text(
                prompt,
                config.max_length,
                truncate_prompt_tokens,
                config.add_special_tokens,
                config.cache_salt,
            )

        # TODO: Also handle embeds prompt using this method
        raise NotImplementedError

    async def _create_prompt_from_text(
        self,
        text: str,
        max_length: int | None,
        truncate_prompt_tokens: int | None,
        add_special_tokens: bool,
        cache_salt: str | None,
    ) -> TokensPrompt:
        """Tokenize text input asynchronously and build a validated prompt.

        Truncation is delegated to the tokenizer (`truncation=True`,
        `max_length=truncate_prompt_tokens`). Which side is cut depends on the
        tokenizer's own configuration.
        """
        async_tokenizer = self._get_async_tokenizer()

        # Handle encoder-specific preprocessing
        if (
            self.model_config.encoder_config is not None
            and self.model_config.encoder_config.get("do_lower_case", False)
        ):
            text = text.lower()

        # Tokenize texts
        if truncate_prompt_tokens is None:
            encoded = await async_tokenizer(text, add_special_tokens=add_special_tokens)
        else:
            encoded = await async_tokenizer(
                text,
                add_special_tokens=add_special_tokens,
                truncation=True,
                max_length=truncate_prompt_tokens,
            )

        return self._create_tokens_prompt(
            token_ids=encoded.input_ids,
            max_length=max_length,
            cache_salt=cache_salt,
            prompt=text,
        )

    async def _create_prompt_from_token_ids(
        self,
        token_ids: list[int],
        max_length: int | None,
        truncate_prompt_tokens: int | None,
        cache_salt: str | None,
        needs_detokenization: bool | None = False,
    ) -> TokensPrompt:
        """Truncate, optionally detokenize, and build a tokens prompt."""
        token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens)

        prompt = None
        if needs_detokenization:
            async_tokenizer = self._get_async_tokenizer()
            prompt = await async_tokenizer.decode(token_ids)

        return self._create_tokens_prompt(
            token_ids=token_ids,
            max_length=max_length,
            cache_salt=cache_salt,
            prompt=prompt,
        )

    def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
        """Get or create async tokenizer using shared pool.

        Raises:
            ValueError: If the renderer was constructed without a tokenizer.
        """
        async_tokenizer = self.async_tokenizer
        if async_tokenizer is not None:
            return async_tokenizer

        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("No tokenizer available for text input processing")

        if self.async_tokenizer_pool is None:
            async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
        else:
            async_tokenizer = self.async_tokenizer_pool.get(tokenizer)
            if async_tokenizer is None:
                async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
                self.async_tokenizer_pool[tokenizer] = async_tokenizer
        self.async_tokenizer = async_tokenizer
        return async_tokenizer

    def _create_tokens_prompt(
        self,
        token_ids: list[int],
        max_length: int | None = None,
        cache_salt: str | None = None,
        prompt: str | None = None,
    ) -> TokensPrompt:
        """Create validated TokensPrompt.

        Raises:
            VLLMValidationError: If `max_length` is set and `token_ids` is
                longer than it.
        """
        if max_length is not None and len(token_ids) > max_length:
            raise VLLMValidationError(
                f"This model's maximum context length is {max_length} tokens. "
                f"However, your request has {len(token_ids)} input tokens. "
                "Please reduce the length of the input messages.",
                parameter="input_tokens",
                value=len(token_ids),
            )

        tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)
        if cache_salt is not None:
            tokens_prompt["cache_salt"] = cache_salt
        if prompt is not None:
            tokens_prompt["prompt"] = prompt
        return tokens_prompt