# "vscode:/vscode.git/clone" did not exist on "ffa443afedd3ffefb8dbe20607692950b78c1496"
# renderer.py 14.9 KB
# Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
5
import io
6
from abc import ABC, abstractmethod
7
from dataclasses import dataclass
8
9
from typing import Annotated, Optional, Union

10
11
import pybase64
import torch
12
13
14
from pydantic import Field

from vllm.config import ModelConfig
15
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
16
17
18
19
20
21
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AsyncMicrobatchTokenizer


22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
@dataclass(frozen=True)
class RenderConfig:
    """Immutable bundle of options that govern prompt preparation.

    One instance is handed to a renderer per request; the renderer reads
    the fields below to decide how to tokenize, truncate and validate the
    incoming prompt.
    """

    max_length: Optional[int] = None
    """Upper bound on the number of input tokens. When set, a prompt
    whose token count exceeds it is rejected with ``ValueError``."""

    truncate_prompt_tokens: Optional[int] = None
    """How many tokens to keep from the prompt.
    ``None`` disables truncation.
    ``0`` makes rendering return an empty list (embeds are skipped too).
    A negative value (conventionally ``-1``) is replaced by
    ``model_config.max_model_len``."""

    add_special_tokens: Optional[bool] = True
    """Forwarded to the tokenizer: whether model-specific special tokens
    are added while tokenizing text."""

    cache_salt: Optional[str] = None
    """Optional string attached to each rendered prompt to keep prefix
    cache entries apart."""

    needs_detokenization: Optional[bool] = False
    """When ``True``, pre-tokenized inputs are decoded back to text so the
    text can be included in outputs."""


45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
class BaseRenderer(ABC):
    """
    Base class for unified input processing and rendering.

    The Renderer serves as a unified input processor that consolidates
    tokenization, chat template formatting, and multimodal input handling
    into a single component.
    It converts high-level API requests (OpenAI-style JSON) into token IDs and
    multimodal features ready for engine consumption.

    Key responsibilities:
    - Convert text prompts to token sequences with proper special tokens
    - Apply chat templates and format conversations
    - Handle multimodal inputs (images, audio, etc.) when applicable
    - Manage prompt truncation and length validation
    - Provide clean separation between API layer and engine core
    """

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: Optional[AnyTokenizer] = None,
    ):
        """Store the model configuration and the (optional) tokenizer.

        Args:
            model_config: Model configuration kept on ``self.model_config``.
            tokenizer: Tokenizer kept on ``self.tokenizer``; may be ``None``.
        """
        super().__init__()
        self.model_config = model_config
        self.tokenizer = tokenizer

    @abstractmethod
    async def render_prompt(
        self,
        *,
        prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
        config: "RenderConfig",
    ) -> list[EngineTokensPrompt]:
        """
        Convert text or token inputs into engine-ready TokensPrompt objects.

        This method accepts text or token inputs and produces a
        list of [`TokensPrompt`][vllm.inputs.data.TokensPrompt] objects
        for the engine.

        Args:
            prompt_or_prompts: One of:
                - ``str``: Single text prompt.
                - ``list[str]``: Batch of text prompts.
                - ``list[int]``: Single pre-tokenized sequence.
                - ``list[list[int]]``: Batch of pre-tokenized sequences.
            config: Render configuration controlling how prompts are prepared
                (e.g., tokenization and length handling).

        Returns:
            list[EngineTokensPrompt]: Engine-ready token prompts.

        Raises:
            ValueError: If input formats are invalid or length limits exceeded.
        """
        raise NotImplementedError

    @abstractmethod
    async def render_prompt_and_embeds(
        self,
        *,
        prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                          list[list[int]]]] = None,
        prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
        config: "RenderConfig",
    ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
        """
        Convert text/token and/or base64-encoded embeddings inputs into
        engine-ready prompt objects using a unified RenderConfig.

        At least one of ``prompt_or_prompts`` or ``prompt_embeds`` must be
        provided and non-empty. If both are omitted or empty (e.g., empty
        string and empty list), implementations are expected to raise
        ``ValueError``.

        Args:
            prompt_or_prompts: Text or token inputs to include.
            prompt_embeds: Base64-encoded bytes (or list thereof) containing a
                torch-saved tensor to be used as prompt embeddings.
            config: Render configuration controlling how prompts are prepared
                (e.g., tokenization and length handling).

        Returns:
            list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
                Engine-ready prompt objects.

        Raises:
            ValueError: If both ``prompt_or_prompts`` and ``prompt_embeds``
                are omitted or empty (decoder prompt cannot be empty), or if
                length limits are exceeded.
        """
        raise NotImplementedError

    @classmethod
    def load_prompt_embeds(
        cls,
        prompt_embeds: Union[bytes, list[bytes]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None,
        cache_salt: Optional[str] = None,
    ) -> list[EngineEmbedsPrompt]:
        """Load and validate base64-encoded embeddings into prompt objects.

        Each payload is base64-decoded, deserialized with ``torch.load``
        (``weights_only=True``, mapped to CPU), validated, densified and
        optionally truncated to its last ``truncate_prompt_tokens`` rows.

        Args:
            prompt_embeds: A single base64-encoded payload or a list of them.
            truncate_prompt_tokens: Number of trailing rows to keep.
                ``None`` keeps everything; ``0`` keeps nothing and yields an
                empty result without decoding any payload.
            cache_salt: If not ``None``, stored under ``"cache_salt"`` on
                every returned prompt.

        Returns:
            One ``EngineEmbedsPrompt`` per payload, in input order (empty
            when ``truncate_prompt_tokens`` is ``0``).

        Raises:
            ValueError: If a payload does not deserialize to a
                ``torch.Tensor`` of dtype float32, bfloat16 or float16, or if
                a tensor with more than two dimensions is not 2-D after its
                leading dimension is squeezed.
        """
        # ``tensor[-0:]`` would keep every row, so "keep zero tokens" has to
        # be handled up front; it yields no prompts at all, mirroring how
        # the renderers treat ``truncate_prompt_tokens == 0``.
        if truncate_prompt_tokens == 0:
            return []

        def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
            tensor = torch.load(
                io.BytesIO(pybase64.b64decode(embed, validate=True)),
                weights_only=True,
                map_location=torch.device("cpu"),
            )
            # Explicit checks instead of ``assert``: the payload comes from
            # the request, and asserts are removed under ``python -O``.
            if not isinstance(tensor, torch.Tensor):
                raise ValueError(
                    "prompt_embeds must decode to a torch.Tensor, "
                    f"got {type(tensor).__name__}.")
            if tensor.dtype not in (
                    torch.float32,
                    torch.bfloat16,
                    torch.float16,
            ):
                raise ValueError(
                    "prompt_embeds must have dtype float32, bfloat16 or "
                    f"float16, got {tensor.dtype}.")
            tensor = tensor.to_dense()
            if tensor.dim() > 2:
                # Drop a leading dimension of size one, if present.
                tensor = tensor.squeeze(0)
                if tensor.dim() != 2:
                    raise ValueError(
                        "prompt_embeds must be 2-dimensional after removing "
                        "the leading dimension, got shape "
                        f"{tuple(tensor.shape)}.")
            if truncate_prompt_tokens is not None:
                # Keep only the trailing rows.
                tensor = tensor[-truncate_prompt_tokens:]
            embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor)
            if cache_salt is not None:
                embeds_prompt["cache_salt"] = cache_salt
            return embeds_prompt

        if isinstance(prompt_embeds, list):
            return [_load_and_validate_embed(embed) for embed in prompt_embeds]

        return [_load_and_validate_embed(prompt_embeds)]
173

174
175
176
177
178
179
180
181
182
183
184

class CompletionRenderer(BaseRenderer):
    """Renderer for completion-style requests.

    Text prompts are tokenized through an ``AsyncMicrobatchTokenizer``;
    pre-tokenized prompts are truncated and, on request, decoded back to
    text. Every result is wrapped in an ``EngineTokensPrompt`` after the
    length check in ``_create_tokens_prompt``.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: Optional[AnyTokenizer] = None,
        async_tokenizer_pool: Optional[dict[AnyTokenizer,
                                            AsyncMicrobatchTokenizer]] = None,
    ):
        """Initialize the renderer.

        Args:
            model_config: Model configuration, forwarded to the base class.
            tokenizer: Tokenizer, forwarded to the base class.
            async_tokenizer_pool: Optional mapping from tokenizer to its
                async wrapper; when given, wrappers are looked up in and
                added to this mapping instead of being created privately.
        """
        super().__init__(model_config, tokenizer)
        self.async_tokenizer_pool = async_tokenizer_pool
        # Resolved lazily by ``_get_async_tokenizer`` and then reused.
        self.async_tokenizer: Optional[AsyncMicrobatchTokenizer] = None

    async def render_prompt(
        self,
        *,
        prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
        config: "RenderConfig",
    ) -> list[EngineTokensPrompt]:
        """Implementation of prompt rendering for completion-style requests.

        Uses async tokenizer pooling for improved performance. See base class
        for detailed parameter documentation.

        Returns:
            One token prompt per parsed input, in input order. An empty list
            is returned when the normalized ``truncate_prompt_tokens`` is 0.

        Raises:
            ValueError: If ``truncate_prompt_tokens`` exceeds ``max_length``,
                if a prompt is longer than ``max_length``, or if a tokenizer
                is required but none is available.
        """
        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
            config.truncate_prompt_tokens, config.max_length)
        if truncate_prompt_tokens == 0:
            return []

        # Parse and batch the input prompts
        batch_inputs = parse_and_batch_prompt(prompt_or_prompts)

        tasks = []
        for prompt_input in batch_inputs:
            if prompt_input["is_tokens"] is True:
                # Token input
                # Note: detokenization is needed when echo is enabled,
                # where the input token IDs are decoded back to text.
                task = self._maybe_detokenize(prompt_input["content"],
                                              config.max_length,
                                              truncate_prompt_tokens,
                                              config.cache_salt,
                                              config.needs_detokenization)
            else:
                # Text input
                task = self._tokenize(prompt_input["content"],
                                      config.max_length,
                                      truncate_prompt_tokens,
                                      config.add_special_tokens,
                                      config.cache_salt)
            tasks.append(task)

        if not tasks:
            return []

        # Run all per-prompt coroutines concurrently; results keep the
        # order in which the coroutines were submitted.
        return await asyncio.gather(*tasks)

    async def render_prompt_and_embeds(
        self,
        *,
        prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                          list[list[int]]]] = None,
        prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
        config: "RenderConfig",
    ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
        """
        Render text/token prompts and/or precomputed embedding prompts.

        Embedding prompts (if any) come first in the result, followed by
        the token prompts. ``prompt_or_prompts`` is skipped when it is
        ``None`` or the empty string.

        Returns:
            The rendered prompts. The list is empty when the normalized
            ``truncate_prompt_tokens`` is 0, or when neither
            ``prompt_embeds`` nor a non-empty ``prompt_or_prompts`` is
            supplied (this method itself does not raise in that case).

        Raises:
            ValueError: If ``truncate_prompt_tokens`` exceeds ``max_length``
                or a token prompt is longer than ``max_length``.
        """
        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
            config.truncate_prompt_tokens, config.max_length)
        if truncate_prompt_tokens == 0:
            return []

        rendered: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = []

        if prompt_embeds is not None:
            rendered.extend(
                self.load_prompt_embeds(prompt_embeds, truncate_prompt_tokens,
                                        config.cache_salt))

        if prompt_or_prompts is None or prompt_or_prompts == "":
            return rendered

        token_prompts = await self.render_prompt(
            prompt_or_prompts=prompt_or_prompts,
            config=config,
        )
        rendered.extend(token_prompts)

        return rendered

    def _validate_and_normalize_truncate_tokens(
        self,
        truncate_prompt_tokens: Optional[int],
        max_length: Optional[int],
    ) -> Optional[int]:
        """Validate and normalize truncate_prompt_tokens parameter.

        Args:
            truncate_prompt_tokens: Requested truncation size.
            max_length: Maximum allowed prompt length, or ``None``.

        Returns:
            ``None`` or ``0`` unchanged; ``model_config.max_model_len`` for
            any negative value; otherwise the value itself.

        Raises:
            ValueError: If the normalized value is greater than
                ``max_length``.
        """
        if truncate_prompt_tokens is None:
            return None

        if truncate_prompt_tokens == 0:
            return 0

        if truncate_prompt_tokens < 0:
            truncate_prompt_tokens = self.model_config.max_model_len

        if max_length is not None and truncate_prompt_tokens > max_length:  # type: ignore[operator]
            raise ValueError(
                f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
                f"cannot be greater than max_length ({max_length}). "
                f"Please select a smaller truncation size.")

        return truncate_prompt_tokens

    def _maybe_apply_truncation(
            self, token_ids: list[int],
            truncate_prompt_tokens: Optional[int]) -> list[int]:
        """Apply truncation to token sequence.

        Args:
            token_ids: Token IDs to truncate.
            truncate_prompt_tokens: Number of trailing tokens to keep;
                ``None`` disables truncation.

        Returns:
            ``token_ids`` itself when no truncation is needed, an empty list
            when ``truncate_prompt_tokens`` is 0, otherwise a new list
            holding the last ``truncate_prompt_tokens`` IDs.
        """
        if truncate_prompt_tokens is None:
            return token_ids
        if truncate_prompt_tokens == 0:
            # ``token_ids[-0:]`` is the whole list, not an empty one.
            return []
        if truncate_prompt_tokens >= len(token_ids):
            return token_ids

        return token_ids[-truncate_prompt_tokens:]

    async def _tokenize(
        self,
        text: str,
        max_length: Optional[int],
        truncate_prompt_tokens: Optional[int],
        add_special_tokens: Optional[bool],
        cache_salt: Optional[str],
    ) -> EngineTokensPrompt:
        """Tokenize text input asynchronously.

        Args:
            text: Prompt text.
            max_length: Maximum allowed number of tokens, or ``None``.
            truncate_prompt_tokens: When not ``None``, passed to the
                tokenizer as ``max_length`` together with ``truncation=True``.
            add_special_tokens: Forwarded to the tokenizer.
            cache_salt: Optional cache salt stored on the result.

        Returns:
            A tokens prompt built from the encoded IDs; its ``"prompt"``
            entry is the text that was tokenized (after optional
            lower-casing).

        Raises:
            ValueError: If no tokenizer is available or the encoded prompt
                is longer than ``max_length``.
        """
        async_tokenizer = self._get_async_tokenizer()

        # Handle encoder-specific preprocessing
        if (self.model_config.encoder_config is not None
                and self.model_config.encoder_config.get(
                    "do_lower_case", False)):
            text = text.lower()

        # Tokenize texts
        if truncate_prompt_tokens is None:
            encoded = await async_tokenizer(
                text, add_special_tokens=add_special_tokens)
        else:
            encoded = await async_tokenizer(
                text,
                add_special_tokens=add_special_tokens,
                truncation=True,
                max_length=truncate_prompt_tokens)

        return self._create_tokens_prompt(encoded.input_ids, max_length,
                                          cache_salt, text)

    async def _maybe_detokenize(
        self,
        token_ids: list[int],
        max_length: Optional[int],
        truncate_prompt_tokens: Optional[int],
        cache_salt: Optional[str],
        needs_detokenization: Optional[bool] = False,
    ) -> EngineTokensPrompt:
        """Optionally detokenize token IDs and build a tokens prompt.

        Args:
            token_ids: Pre-tokenized prompt.
            max_length: Maximum allowed number of tokens, or ``None``.
            truncate_prompt_tokens: Number of trailing tokens to keep, or
                ``None`` for no truncation.
            cache_salt: Optional cache salt stored on the result.
            needs_detokenization: Only when exactly ``True`` are the
                (truncated) IDs decoded and the text stored as ``"prompt"``.

        Returns:
            A tokens prompt for the (possibly truncated) IDs.

        Raises:
            ValueError: If the prompt is longer than ``max_length``, or if
                decoding is requested but no tokenizer is available.
        """
        token_ids = self._maybe_apply_truncation(token_ids,
                                                 truncate_prompt_tokens)

        prompt = None
        if needs_detokenization is True:
            async_tokenizer = self._get_async_tokenizer()
            prompt = await async_tokenizer.decode(token_ids)

        return self._create_tokens_prompt(token_ids=token_ids,
                                          max_length=max_length,
                                          cache_salt=cache_salt,
                                          prompt=prompt)

    def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
        """Get or create async tokenizer using shared pool.

        The wrapper is cached on ``self.async_tokenizer`` after the first
        call. When a pool was supplied, the wrapper is taken from it (and
        inserted if missing); otherwise a private wrapper is created.

        Raises:
            ValueError: If the renderer has no tokenizer.
        """
        async_tokenizer = self.async_tokenizer
        if async_tokenizer is not None:
            return async_tokenizer

        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError(
                "No tokenizer available for text input processing")

        if self.async_tokenizer_pool is None:
            async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
        else:
            async_tokenizer = self.async_tokenizer_pool.get(tokenizer)
            if async_tokenizer is None:
                async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
                self.async_tokenizer_pool[tokenizer] = async_tokenizer
        self.async_tokenizer = async_tokenizer
        return async_tokenizer

    def _create_tokens_prompt(
        self,
        token_ids: list[int],
        max_length: Optional[int] = None,
        cache_salt: Optional[str] = None,
        prompt: Optional[str] = None,
    ) -> EngineTokensPrompt:
        """Create validated EngineTokensPrompt.

        Args:
            token_ids: Token IDs stored as ``prompt_token_ids``.
            max_length: Maximum allowed number of tokens, or ``None`` to
                skip the check.
            cache_salt: Stored under ``"cache_salt"`` when not ``None``.
            prompt: Stored under ``"prompt"`` when not ``None``.

        Raises:
            ValueError: If ``len(token_ids)`` exceeds ``max_length``.
        """
        if max_length is not None and len(token_ids) > max_length:
            raise ValueError(
                f"This model's maximum context length is {max_length} tokens. "
                f"However, your request has {len(token_ids)} input tokens. "
                "Please reduce the length of the input messages.")

        tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids)
        if cache_salt is not None:
            tokens_prompt["cache_salt"] = cache_salt
        if prompt is not None:
            tokens_prompt["prompt"] = prompt
        return tokens_prompt