base.py 28 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import copy
5
import time
6
from abc import ABC, abstractmethod
7
from collections.abc import Mapping, Sequence
8
9
10
11
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, overload

from typing_extensions import TypeVar
12

13
from vllm.inputs import (
14
    EmbedsInput,
15
    EmbedsPrompt,
16
17
18
19
20
21
    EncoderDecoderInput,
    EngineInput,
    MultiModalDataDict,
    MultiModalInput,
    MultiModalUUIDDict,
    SingletonInput,
22
    TextPrompt,
23
    TokensInput,
24
    TokensPrompt,
25
26
27
    build_enc_dec_input,
    embeds_input,
    tokens_input,
28
)
29
from vllm.logger import init_logger
30
31
32
33
34
35
36
37
38
39
from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.parse import (
    MultiModalDataItems,
    MultiModalUUIDItems,
    parse_mm_uuids,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
from vllm.multimodal.registry import MultiModalTimingRegistry
40
from vllm.tokenizers import TokenizerLike
41
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
42
from vllm.utils.counter import AtomicCounter
43
44
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.metrics.stats import MultiModalCacheStats
45
46

from .embed_utils import safe_load_prompt_embeds
47
48
49
50
from .inputs import (
    DictPrompt,
    EncoderDecoderDictPrompt,
    EncoderDecoderTokPrompt,
51
52
    SingletonDictPrompt,
    SingletonTokPrompt,
53
54
    TokPrompt,
)
55
from .inputs.preprocess import extract_target_prompt
56
from .params import ChatParams, TokenizeParams
57
58

if TYPE_CHECKING:
59
    from vllm.config import VllmConfig
60
61
62
63
64
    from vllm.entrypoints.chat_utils import (
        ChatCompletionMessageParam,
        ConversationMessage,
    )

65
66
logger = init_logger(__name__)

67

68
69
70
71
72
# Type variable for the tokenizer a renderer works with; it is bound to
# `TokenizerLike` and falls back to `TokenizerLike` when left unparameterized.
_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)


class BaseRenderer(ABC, Generic[_T]):
    def __init__(self, config: "VllmConfig", tokenizer: _T | None) -> None:
        """
        Store the configuration and tokenizer, and build the multi-modal
        processor when the model is multi-modal.

        Args:
            config: Engine configuration; its model, parallel, and
                observability sub-configs are read here.
            tokenizer: Tokenizer used for rendering, or `None` when no
                tokenizer is available.
        """
        super().__init__()

        model_config = config.model_config

        self.config = config
        self.model_config = model_config
        self.api_process_rank = config.parallel_config._api_process_rank
        self.tokenizer = tokenizer

        # Lazy initialization since offline LLM doesn't use async
        self._async_tokenizer: AsyncMicrobatchTokenizer | None = None

        self.mm_processor: BaseMultiModalProcessor | None = None
        self._mm_cache_stats: MultiModalCacheStats | None = None

        # Everything below is only needed by multi-modal models.
        if not model_config.is_multimodal_model:
            return

        processor_cache = mm_registry.processor_cache_from_config(config)

        # Deep-copy the tokenizer so the multimodal processor gets its
        # own Rust tokenizer backend.  Without this, concurrent access
        # from AsyncMicrobatchTokenizer and call_hf_processor causes
        # "RuntimeError: Already borrowed" from the Rust RefCell.
        # See: https://github.com/huggingface/tokenizers/issues/537
        mm_tokenizer = copy.deepcopy(tokenizer)

        with set_default_torch_num_threads():
            self.mm_processor = mm_registry.create_processor(
                model_config,
                tokenizer=mm_tokenizer,
                cache=processor_cache,
            )

        if processor_cache:
            self._mm_cache_stats = MultiModalCacheStats()

        # This is used to generate internal request ID for MM processing
        # It has no relation to the request ID for engine core
        self._mm_req_counter = AtomicCounter()
        self._mm_timing_registry = MultiModalTimingRegistry(
            config.observability_config
        )
112

113
    def get_tokenizer(self) -> _T:
        """
        Return the tokenizer held by this renderer.

        Raises:
            ValueError: If the renderer was created without a tokenizer.
        """
        if self.tokenizer is not None:
            return self.tokenizer

        raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")

120
    def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
        """
        Return the async tokenizer wrapper, creating it on first use.

        The wrapper is built around `get_tokenizer()`, so this raises the
        same `ValueError` when no tokenizer is available.
        """
        async_tokenizer = self._async_tokenizer
        if async_tokenizer is None:
            async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer())
            self._async_tokenizer = async_tokenizer

        return async_tokenizer

126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
    def get_mm_processor(self) -> "BaseMultiModalProcessor":
        """
        Return the multi-modal processor held by this renderer.

        Raises:
            ValueError: If no multi-modal processor has been set.
        """
        mm_processor = self.mm_processor
        if mm_processor is not None:
            return mm_processor

        raise ValueError("Multi-modal processor not available for text-only models")

    @property
    def mm_processor_cache(self) -> "BaseMultiModalProcessorCache | None":
        """
        The cache attached to the multi-modal processor, or `None` when
        there is no multi-modal processor.
        """
        mm_processor = self.mm_processor
        return None if mm_processor is None else mm_processor.cache

    def stat_mm_cache(self) -> MultiModalCacheStats | None:
        """
        Hand out the accumulated multi-modal cache stats and start a fresh
        accumulator.

        Returns:
            The stats object collected so far, or `None` if stats are not
            being tracked.
        """
        collected = self._mm_cache_stats
        if collected is not None:
            # Swap in a new accumulator so the caller owns the old one.
            self._mm_cache_stats = MultiModalCacheStats()

        return collected

    def update_mm_cache_stats(self) -> None:
        """
        Record the processor cache's latest delta stats into the stats
        accumulator; does nothing unless both are present.
        """
        cache = self.mm_processor_cache
        stats = self._mm_cache_stats

        if not (cache and stats):
            return

        delta = cache.make_stats(delta=True)
        stats.record(delta.total, delta.hits)

    def clear_mm_cache(self) -> None:
        """
        Clear the multi-modal processor cache (if any) and set the `reset`
        flag on the stats accumulator (if any).
        """
        if (cache := self.mm_processor_cache) is not None:
            cache.clear_cache()

        if (stats := self._mm_cache_stats) is not None:
            stats.reset = True

164
165
166
167
168
169
170
    def warmup(self, chat_params: ChatParams) -> None:
        """
        Warm up this renderer to avoid first-request latency.

        For chat requests:
        - Jinja2 template compilation

        For multi-modal models:
        - One pass of the multi-modal processor over dummy inputs, after
          which the multi-modal cache is cleared again

        Exceptions raised by the chat rendering or by the dummy multi-modal
        pass are logged (with traceback) instead of being propagated.

        Args:
            chat_params: Chat parameters used to render the warmup message.
        """
        from vllm.entrypoints.chat_utils import ChatTemplateResolutionError

        try:
            logger.debug("Warming up chat template processing...")
            start_time = time.perf_counter()

            self.render_chat([[{"role": "user", "content": "warmup"}]], chat_params)

            elapsed = time.perf_counter() - start_time
            logger.debug("Chat template warmup completed in %.3fs", elapsed)
        except ChatTemplateResolutionError:
            logger.debug("This model does not support chat template.")
        except Exception:
            logger.warning("Chat template warmup failed", exc_info=True)

        if self.mm_processor:
            from vllm.multimodal.processing import TimingContext

            model_config = self.model_config
            mm_config = model_config.get_multimodal_config()
            processor = self.mm_processor
            mm_limits = processor.info.allowed_mm_limits

            try:
                logger.debug("Warming up multi-modal processing...")
                start_time = time.perf_counter()

                # One dummy item per allowed modality.
                processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
                    seq_len=model_config.max_model_len,
                    mm_counts=dict.fromkeys(mm_limits, 1),
                    mm_options=mm_config.limit_per_prompt,
                )
                _ = processor.apply(
                    processor_inputs, timing_ctx=TimingContext(enabled=False)
                )

                elapsed = time.perf_counter() - start_time
                logger.info("Multi-modal warmup completed in %.3fs", elapsed)
            except Exception:
                # Keep the traceback, matching the chat-template branch above;
                # without it the failure cannot be diagnosed from the logs.
                logger.warning("Multi-modal warmup failed", exc_info=True)
            finally:
                # Do not let warmup entries linger in the cache or its stats.
                self.clear_mm_cache()

214
215
216
217
218
    def shutdown(self) -> None:
        """
        Close the multi-modal processor cache, if there is one.
        """
        if (cache := self.mm_processor_cache) is not None:
            cache.close()

219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
    def get_bos_token_id(self) -> int | None:
        if self.tokenizer is None:
            logger.warning_once(
                "Using None for BOS token id because tokenizer is not initialized"
            )
            return None

        return self.tokenizer.bos_token_id

    def get_eos_token_id(self) -> int | None:
        if self.tokenizer is None:
            logger.warning_once(
                "Using None for EOS token id because tokenizer is not initialized"
            )
            return None

        return self.tokenizer.eos_token_id

237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
    def get_dec_start_token_id(self) -> int:
        """
        Look up the decoder start token id for an encoder/decoder model.

        The HF config's `decoder_start_token_id` is used when present;
        otherwise the BOS token id is used as a fallback.

        Raises:
            RuntimeError: If neither id is available.
        """
        hf_config = self.model_config.hf_config
        token_id = getattr(hf_config, "decoder_start_token_id", None)
        if token_id is not None:
            return token_id

        logger.warning_once(
            "Falling back on <BOS> for decoder start token id "
            "because decoder start token id is not available."
        )
        token_id = self.get_bos_token_id()
        if token_id is None:
            raise RuntimeError("Cannot find decoder start token id or <BOS>")

        return token_id

258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
    @cached_property
    def default_cmpl_tok_params(self) -> TokenizeParams:
        """
        Default tokenization parameters for completion-style requests.

        When a multi-modal processor exists, its `info.default_tok_params`
        is returned. Otherwise the parameters are built from the model
        config with `add_special_tokens=True`.
        """
        if (mm_processor := self.mm_processor) is not None:
            return mm_processor.info.default_tok_params

        model_config = self.model_config
        lower_case = (model_config.encoder_config or {}).get("do_lower_case", False)

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            do_lower_case=lower_case,
            add_special_tokens=True,
        )

    @cached_property
    def default_chat_tok_params(self) -> TokenizeParams:
        """
        Default tokenization parameters for chat-style requests.

        When a multi-modal processor exists, its `info.default_tok_params`
        is returned. Otherwise the parameters are built from the model
        config with `add_special_tokens=False`.
        """
        if (mm_processor := self.mm_processor) is not None:
            return mm_processor.info.default_tok_params

        model_config = self.model_config
        lower_case = (model_config.encoder_config or {}).get("do_lower_case", False)

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            do_lower_case=lower_case,
            add_special_tokens=False,
        )

288
    # Step 1: Convert raw inputs to prompts
289
    def render_prompt(
        self,
        prompt: DictPrompt | bytes,
    ) -> DictPrompt:
        """
        Convert a raw input into a dictionary prompt.

        `bytes` inputs are passed to `safe_load_prompt_embeds` and wrapped
        in an `EmbedsPrompt`; any other input is returned unchanged.
        """
        if not isinstance(prompt, bytes):
            return prompt

        embeds = safe_load_prompt_embeds(self.model_config, prompt)
        return EmbedsPrompt(prompt_embeds=embeds)
298

299
    def render_prompts(
        self,
        prompts: Sequence[DictPrompt | bytes],
    ) -> list[DictPrompt]:
        """
        Apply `render_prompt` to every input, preserving order.

        Raises:
            ValueError: If `prompts` is empty.
        """
        if len(prompts) == 0:
            raise ValueError("You must pass at least one prompt")

        return list(map(self.render_prompt, prompts))
307

308
    async def render_prompts_async(
        self,
        prompts: Sequence[DictPrompt | bytes],
    ) -> list[DictPrompt]:
        """
        Async counterpart of `render_prompts`; delegates to the
        synchronous implementation.
        """
        dict_prompts = self.render_prompts(prompts)
        return dict_prompts
313

314
    @abstractmethod
    def render_messages(
        self,
        messages: list["ChatCompletionMessageParam"],
        params: ChatParams,
    ) -> tuple[list["ConversationMessage"], DictPrompt]:
        """
        Render one conversation into a prompt.

        Subclasses must implement this.

        Args:
            messages: The chat messages of a single conversation.
            params: Chat parameters to render with.

        Returns:
            A pair of the conversation messages and the resulting
            dictionary prompt.
        """
        raise NotImplementedError

    async def render_messages_async(
        self,
        messages: list["ChatCompletionMessageParam"],
        params: ChatParams,
    ) -> tuple[list["ConversationMessage"], DictPrompt]:
        """
        Async counterpart of `render_messages`; this base implementation
        delegates to the synchronous method.
        """
        rendered = self.render_messages(messages, params)
        return rendered

    # Step 2: Tokenize prompts if necessary
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
    def _tokenize_prompt(
        self,
        prompt: TextPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt:
        """
        Encode the text of `prompt` and return a tokens prompt that keeps
        all of the original prompt's keys.
        """
        encode = self.get_tokenizer().encode
        text = prompt["prompt"]
        token_ids = encode(text, **params.get_encode_kwargs())

        return TokensPrompt(prompt_token_ids=token_ids, **prompt)

    async def _tokenize_prompt_async(
        self,
        prompt: TextPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt:
        """
        Async variant of `_tokenize_prompt` that encodes through the async
        tokenizer wrapper.
        """
        encode = self.get_async_tokenizer().encode
        text = prompt["prompt"]
        token_ids = await encode(text, **params.get_encode_kwargs())

        return TokensPrompt(prompt_token_ids=token_ids, **prompt)

    def _detokenize_prompt(self, prompt: TokensPrompt) -> TokensPrompt:
        """
        Decode the prompt's token ids and store the text under `"prompt"`.

        The given prompt is modified in place and also returned.
        """
        decode = self.get_tokenizer().decode
        prompt["prompt"] = decode(prompt["prompt_token_ids"])

        return prompt

    async def _detokenize_prompt_async(self, prompt: TokensPrompt) -> TokensPrompt:
        """
        Async variant of `_detokenize_prompt` that decodes through the
        async tokenizer wrapper.

        The given prompt is modified in place and also returned.
        """
        decode = self.get_async_tokenizer().decode
        prompt["prompt"] = await decode(prompt["prompt_token_ids"])

        return prompt

368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
    @overload
    def _tokenize_singleton_prompt(
        self,
        prompt: TextPrompt | TokensPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt: ...

    @overload
    def _tokenize_singleton_prompt(  # type: ignore[misc]
        self,
        prompt: EmbedsPrompt,
        params: TokenizeParams,
    ) -> EmbedsPrompt: ...

    def _tokenize_singleton_prompt(
        self,
        prompt: SingletonDictPrompt,
        params: TokenizeParams,
    ) -> SingletonTokPrompt:
        """
        Tokenize a single (non encoder/decoder) prompt.

        Prompts carrying neither token ids nor embeddings are
        pre-processed and tokenized. If `params` asks for detokenization
        and the prompt has no text, the text is recovered from the token
        ids. Post-tokenization processing from `params` is always applied.

        Raises:
            RuntimeError: If detokenization is requested for a prompt
                without token ids.
        """
        already_encoded = "prompt_token_ids" in prompt or "prompt_embeds" in prompt
        if not already_encoded:
            prompt = params.apply_pre_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]
            prompt = self._tokenize_prompt(prompt, params)

        missing_text = params.needs_detokenization and "prompt" not in prompt
        if missing_text:
            if "prompt_token_ids" not in prompt:
                raise RuntimeError("Cannot run detokenization on embeddings")

            prompt = self._detokenize_prompt(prompt)  # type: ignore[arg-type]

        return params.apply_post_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]

    @overload
    async def _tokenize_singleton_prompt_async(
        self,
        prompt: TextPrompt | TokensPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt: ...

    @overload
    async def _tokenize_singleton_prompt_async(  # type: ignore[misc]
        self,
        prompt: EmbedsPrompt,
        params: TokenizeParams,
    ) -> EmbedsPrompt: ...

    async def _tokenize_singleton_prompt_async(
        self,
        prompt: SingletonDictPrompt,
        params: TokenizeParams,
    ) -> SingletonTokPrompt:
        """
        Async variant of `_tokenize_singleton_prompt`.

        Tokenization and detokenization go through the async helpers; the
        pre- and post-tokenization steps from `params` are called directly.

        Raises:
            RuntimeError: If detokenization is requested for a prompt
                without token ids.
        """
        already_encoded = "prompt_token_ids" in prompt or "prompt_embeds" in prompt
        if not already_encoded:
            prompt = params.apply_pre_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]
            prompt = await self._tokenize_prompt_async(prompt, params)

        missing_text = params.needs_detokenization and "prompt" not in prompt
        if missing_text:
            if "prompt_token_ids" not in prompt:
                raise RuntimeError("Cannot run detokenization on embeddings")

            prompt = await self._detokenize_prompt_async(prompt)  # type: ignore[arg-type]

        return params.apply_post_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]

430
431
432
433
434
435
    def _tokenize_enc_dec_prompt(
        self,
        prompt: EncoderDecoderDictPrompt,
        params: TokenizeParams,
    ) -> EncoderDecoderTokPrompt:
        """
        Tokenize both halves of an encoder/decoder prompt.

        The decoder half stays `None` when the input has no decoder prompt.
        """
        enc_prompt = self._tokenize_singleton_prompt(prompt["encoder_prompt"], params)

        dec_prompt = None
        raw_dec_prompt = prompt["decoder_prompt"]
        if raw_dec_prompt is not None:
            dec_prompt = self._tokenize_singleton_prompt(raw_dec_prompt, params)

        return EncoderDecoderTokPrompt(
            encoder_prompt=enc_prompt,
            decoder_prompt=dec_prompt,
        )

    async def _tokenize_enc_dec_prompt_async(
        self,
        prompt: EncoderDecoderDictPrompt,
        params: TokenizeParams,
    ) -> EncoderDecoderTokPrompt:
        """
        Async variant of `_tokenize_enc_dec_prompt`; both halves are
        awaited together via `asyncio.gather`.

        The decoder half stays `None` when the input has no decoder prompt.
        """
        enc_task = self._tokenize_singleton_prompt_async(
            prompt["encoder_prompt"], params
        )

        raw_dec_prompt = prompt["decoder_prompt"]
        if raw_dec_prompt is None:
            # Placeholder awaitable that resolves to None.
            dec_task = asyncio.sleep(0)
        else:
            dec_task = self._tokenize_singleton_prompt_async(raw_dec_prompt, params)

        enc_prompt, dec_prompt = await asyncio.gather(enc_task, dec_task)

        return EncoderDecoderTokPrompt(
            encoder_prompt=enc_prompt,
            decoder_prompt=dec_prompt,
        )

    def tokenize_prompt(
        self,
        prompt: DictPrompt,
        params: TokenizeParams,
    ) -> TokPrompt:
        """
        Tokenize one prompt, dispatching on whether it contains an
        `"encoder_prompt"` key.
        """
        if "encoder_prompt" not in prompt:
            return self._tokenize_singleton_prompt(prompt, params)

        return self._tokenize_enc_dec_prompt(prompt, params)  # type: ignore[arg-type]
479
480
481

    def tokenize_prompts(
        self,
        prompts: Sequence[DictPrompt],
        params: TokenizeParams,
    ) -> list[TokPrompt]:
        """
        Tokenize every prompt with the same parameters, preserving order.
        """
        tokenize = self.tokenize_prompt
        return [tokenize(dict_prompt, params) for dict_prompt in prompts]

487
488
489
490
491
492
493
    async def tokenize_prompt_async(
        self,
        prompt: DictPrompt,
        params: TokenizeParams,
    ) -> TokPrompt:
        """
        Async variant of `tokenize_prompt`, dispatching on whether the
        prompt contains an `"encoder_prompt"` key.
        """
        if "encoder_prompt" not in prompt:
            return await self._tokenize_singleton_prompt_async(prompt, params)

        return await self._tokenize_enc_dec_prompt_async(prompt, params)  # type: ignore[arg-type]
496
497
498

    async def tokenize_prompts_async(
        self,
        prompts: Sequence[DictPrompt],
        params: TokenizeParams,
    ) -> list[TokPrompt]:
        """
        Tokenize every prompt via `asyncio.gather`; results come back in
        the order of `prompts`.
        """
        pending = [
            self.tokenize_prompt_async(dict_prompt, params) for dict_prompt in prompts
        ]
        return await asyncio.gather(*pending)
505
506
507
508

    # Step 3: Add extra keys to the prompts
    def _apply_prompt_extras(
        self,
        prompts: Sequence[TokPrompt],
        prompt_extras: dict[str, Any] | None,
    ):
        """
        Merge `prompt_extras` into the target prompt of every entry.

        The target is whatever `extract_target_prompt` selects for this
        model config. Nothing happens when `prompt_extras` is empty or
        `None`. Prompts are updated in place.
        """
        if prompt_extras:
            for tok_prompt in prompts:
                target = extract_target_prompt(self.model_config, tok_prompt)
                target.update(prompt_extras)  # type: ignore[arg-type]

519
520
521
    # Step 4: Convert to engine inputs
    def _validate_mm_uuids(
        self,
        mm_data: MultiModalDataDict,
        mm_data_items: MultiModalDataItems,
        mm_uuid_items: MultiModalUUIDItems,
    ) -> None:
        """
        Check that multi-modal data and UUIDs are consistent per modality.

        Raises:
            ValueError: If a modality (or one of its items) has neither
                data nor a UUID, or if the number of UUIDs for a modality
                differs from the number of its data items.
        """
        # NOTE: Keys corresponding to `None` in `mm_data` don't appear in
        # `mm_data_items`
        for modality in mm_data.keys() | mm_uuid_items.keys():
            data_items = mm_data_items.get(modality)
            uuid_items = mm_uuid_items.get(modality)

            if data_items is None:
                if uuid_items is None:
                    raise ValueError(
                        f"multi_modal_data[{modality!r}] is empty but "
                        f"multi_modal_uuids[{modality!r}] is missing."
                    )
                continue

            # Data without UUIDs needs no further checks.
            if uuid_items is None:
                continue

            if len(data_items) != len(uuid_items):
                raise ValueError(
                    f"If given, multi_modal_uuids[{modality!r}] must have "
                    f"same length as multi_modal_data[{modality!r}], but "
                    f"got {len(uuid_items)} vs {len(data_items)}."
                )

            for i, item in enumerate(data_items):
                if item is None and uuid_items[i] is None:
                    raise ValueError(
                        f"multi_modal_data[{modality!r}][{i}] is empty but "
                        f"multi_modal_uuids[{modality!r}][{i}] is missing."
                    )
555
556
557

    def _process_mm_uuids(
        self,
        mm_data: MultiModalDataDict,
        mm_data_items: MultiModalDataItems,
        mm_uuid_items: MultiModalUUIDItems,
        mm_req_id: str,
    ) -> MultiModalUUIDItems:
        """
        Decide which UUIDs identify the multi-modal items, then validate
        them against the data.

        Returns:
            The UUID items to use: either the ones passed in or, when all
            caching is disabled, ones generated from `mm_req_id`.
        """
        mm_config = self.model_config.multimodal_config

        # NOTE: When users explicitly turn off BOTH prefix caching and input
        # processing caching, no multimodal features or embeddings will be
        # reused across requests, therefore identifying multimodal data items
        # by their content is no longer necessary, and we create uuids with
        # `<mm_req_id>-<modality>-<index>`, overriding even user-provided ones.
        caching_disabled = (
            mm_config
            and mm_config.mm_processor_cache_gb == 0
            and not self.config.cache_config.enable_prefix_caching
        )
        if caching_disabled:
            counts = mm_data_items.get_all_counts()
            mm_uuid_items = {
                modality: [f"{mm_req_id}-{modality}-{i}" for i in range(count)]
                for modality, count in counts.items()
            }

        self._validate_mm_uuids(mm_data, mm_data_items, mm_uuid_items)

        return mm_uuid_items
583
584
585
586
587

    # TODO: Remove str and tokenization_kwargs after deprecating InputPreprocessor
    def _process_multimodal(
        self,
        prompt: list[int] | str,
        mm_data: MultiModalDataDict,
        mm_uuids: MultiModalUUIDDict | None,
        mm_processor_kwargs: Mapping[str, object] | None,
        tokenization_kwargs: dict[str, Any] | None,
    ) -> "MultiModalInput":
        """
        Run the multi-modal processor on a prompt and its multi-modal data.

        Args:
            prompt: Prompt token ids, or prompt text.
            mm_data: Multi-modal data attached to the prompt.
            mm_uuids: Optional user-provided UUIDs for the data items.
            mm_processor_kwargs: Extra kwargs for the HF processor; `None`
                is treated as empty.
            tokenization_kwargs: Extra tokenization kwargs; `None` is
                treated as empty.

        Returns:
            The output of the multi-modal processor.

        Raises:
            ValueError: If this renderer has no multi-modal processor, or
                if the data and UUIDs are inconsistent.
        """
        # Resolve the processor before touching any other multi-modal state:
        # `_mm_req_counter` and `_mm_timing_registry` are only created for
        # multi-modal models, so reading them first would surface an opaque
        # AttributeError instead of the intended ValueError.
        mm_processor = self.get_mm_processor()

        mm_req_id = f"renderer{self.api_process_rank}-mm-{self._mm_req_counter.inc(1)}"

        mm_data_items = mm_processor.info.parse_mm_data(mm_data)
        mm_uuid_items = parse_mm_uuids(mm_uuids)

        mm_uuid_items = self._process_mm_uuids(
            mm_data, mm_data_items, mm_uuid_items, mm_req_id
        )

        mm_processor_inputs = MMProcessorInputs(
            prompt,
            mm_data_items,
            mm_uuid_items,
            hf_processor_mm_kwargs=mm_processor_kwargs or {},
            tokenization_kwargs=tokenization_kwargs or {},
        )
        mm_timing_ctx = self._mm_timing_registry.get(mm_req_id)

        with set_default_torch_num_threads():
            mm_inputs = mm_processor.apply(mm_processor_inputs, mm_timing_ctx)

        # Fold the cache delta of this call into the stats accumulator.
        self.update_mm_cache_stats()

        return mm_inputs

    def _process_tokens(
        self,
        prompt: TokensPrompt,
    ) -> TokensInput | MultiModalInput:
        """
        Convert a tokens prompt into an engine input.

        Prompts with non-empty multi-modal data go through the multi-modal
        processor; all others become plain token inputs. Non-empty
        `"prompt"` and `"cache_salt"` values are copied onto the result.
        """
        token_ids = prompt["prompt_token_ids"]
        mm_data = prompt.get("multi_modal_data")

        engine_input: TokensInput | MultiModalInput
        if not mm_data:
            engine_input = tokens_input(token_ids)
        else:
            engine_input = self._process_multimodal(
                token_ids,
                mm_data,
                mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
                tokenization_kwargs=None,  # Tokenization already done in Step 2
                mm_uuids=prompt.get("multi_modal_uuids"),
            )

        prompt_text = prompt.get("prompt")
        if prompt_text:
            engine_input["prompt"] = prompt_text

        cache_salt = prompt.get("cache_salt")
        if cache_salt:
            engine_input["cache_salt"] = cache_salt

        return engine_input
644

645
    def _process_embeds(self, prompt: EmbedsPrompt) -> EmbedsInput:
        """
        Convert an embeddings prompt into an engine input.

        Raises:
            ValueError: If prompt embeddings are not enabled in the model
                config, or if the embeddings are not 2-dimensional after a
                leading batch dimension has been squeezed.
        """
        if not self.model_config.enable_prompt_embeds:
            raise ValueError(
                "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
            )

        embeds = prompt["prompt_embeds"]

        # prompt_embeds must be (seq_len, hidden_size), but if the user
        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # we can unambiguously process the intent by squeezing the batch
        # dimension.
        if embeds.ndim == 3:
            embeds = embeds.squeeze(dim=0)

        if embeds.ndim != 2:
            raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")

        # Tensors must be on CPU for serialization between processes
        # in the MsgpackEncoder. Casting to CPU here ensures that there is no
        # hidden device transfer in the critical path of generation.
        cpu_embeds = embeds.cpu()

        return embeds_input(
            prompt_embeds=cpu_embeds,
            cache_salt=prompt.get("cache_salt"),
        )

673
    def _process_singleton(self, prompt: SingletonTokPrompt) -> SingletonInput:
        """
        Convert a single prompt into an engine input, dispatching on
        whether it contains a `"prompt_embeds"` key.
        """
        if "prompt_embeds" not in prompt:
            return self._process_tokens(prompt)  # type: ignore[arg-type]

        return self._process_embeds(prompt)  # type: ignore[arg-type]

    def _process_enc_dec(
        self,
        prompt: EncoderDecoderTokPrompt,
    ) -> EncoderDecoderInput:
        """
        Convert an encoder/decoder prompt into an engine input.

        Both halves are processed as singleton prompts (the decoder half
        may be `None`) and combined with the decoder start token id.
        """
        enc_prompt = prompt["encoder_prompt"]
        dec_prompt = prompt["decoder_prompt"]

        # Only an encoder/decoder multi-modal processor can ask to skip the
        # decoder start token.
        mm_processor = self.mm_processor
        skip_start_token = False
        if mm_processor is not None:
            from vllm.multimodal.processing import EncDecMultiModalProcessor

            if isinstance(mm_processor, EncDecMultiModalProcessor):
                skip_start_token = mm_processor.skip_decoder_start_token

        encoder_input = self._process_singleton(enc_prompt)

        decoder_input = None
        if dec_prompt is not None:
            decoder_input = self._process_singleton(dec_prompt)

        return build_enc_dec_input(
            encoder_input=encoder_input,
            decoder_input=decoder_input,
            decoder_start_token_id=self.get_dec_start_token_id(),
            skip_decoder_start_token=skip_start_token,
        )

702
703
    def process_for_engine(self, prompt: TokPrompt, arrival_time: float) -> EngineInput:
        """
        Convert a tokenized prompt into an engine input and stamp it with
        `arrival_time`.

        Prompts containing an `"encoder_prompt"` key are handled as
        encoder/decoder prompts; all others as singleton prompts.
        """
        engine_input: EngineInput = (
            self._process_enc_dec(prompt)  # type: ignore[arg-type]
            if "encoder_prompt" in prompt
            else self._process_singleton(prompt)
        )
        engine_input["arrival_time"] = arrival_time

        return engine_input
712

713
714
715
716
    # Top-level methods
    def render_cmpl(
        self,
        prompts: Sequence[DictPrompt | bytes],
        tok_params: TokenizeParams | None = None,
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
        """
        Render completion-style prompts into engine inputs.

        Args:
            prompts: Raw prompts (dictionary prompts or serialized bytes).
            tok_params: Tokenization parameters; defaults to
                `default_cmpl_tok_params`.
            prompt_extras: Optional extra keys merged into every prompt
                after tokenization.

        Returns:
            One engine input per prompt, all sharing the same arrival time.
        """
        arrival_time = time.time()

        params = self.default_cmpl_tok_params if tok_params is None else tok_params

        tok_prompts = self.tokenize_prompts(self.render_prompts(prompts), params)
        self._apply_prompt_extras(tok_prompts, prompt_extras)

        to_engine = self.process_for_engine
        return [to_engine(tok_prompt, arrival_time) for tok_prompt in tok_prompts]
732
733
734
735

    async def render_cmpl_async(
        self,
        prompts: Sequence[DictPrompt | bytes],
        tok_params: TokenizeParams | None = None,
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
        """
        Async variant of `render_cmpl`.

        Args:
            prompts: Raw prompts (dictionary prompts or serialized bytes).
            tok_params: Tokenization parameters; defaults to
                `default_cmpl_tok_params`.
            prompt_extras: Optional extra keys merged into every prompt
                after tokenization.

        Returns:
            One engine input per prompt, all sharing the same arrival time.
        """
        arrival_time = time.time()

        params = self.default_cmpl_tok_params if tok_params is None else tok_params

        dict_prompts = await self.render_prompts_async(prompts)
        tok_prompts = await self.tokenize_prompts_async(dict_prompts, params)
        self._apply_prompt_extras(tok_prompts, prompt_extras)

        to_engine = self.process_for_engine
        return [to_engine(tok_prompt, arrival_time) for tok_prompt in tok_prompts]
751
752
753
754
755

    def render_chat(
        self,
        conversations: Sequence[list["ChatCompletionMessageParam"]],
        chat_params: ChatParams,
        tok_params: TokenizeParams | None = None,
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
        """
        Render chat conversations into engine inputs.

        Args:
            conversations: One message list per conversation.
            chat_params: Chat parameters used to render the messages.
            tok_params: Tokenization parameters; defaults to
                `default_chat_tok_params`.
            prompt_extras: Optional extra keys merged into every prompt
                after tokenization.

        Returns:
            A pair `(conversations, engine_inputs)` with one entry per
            input conversation.
        """
        arrival_time = time.time()

        params = self.default_chat_tok_params if tok_params is None else tok_params

        rendered = [
            self.render_messages(conversation, chat_params)
            for conversation in conversations
        ]
        out_conversations: list[list["ConversationMessage"]] = [
            conv for conv, _ in rendered
        ]
        dict_prompts: list[DictPrompt] = [dict_prompt for _, dict_prompt in rendered]

        tok_prompts = self.tokenize_prompts(dict_prompts, params)
        self._apply_prompt_extras(tok_prompts, prompt_extras)

        to_engine = self.process_for_engine
        eng_prompts = [to_engine(tok_prompt, arrival_time) for tok_prompt in tok_prompts]

        return out_conversations, eng_prompts
785
786
787
788
789

    async def render_chat_async(
        self,
        conversations: Sequence[list["ChatCompletionMessageParam"]],
        chat_params: ChatParams,
        tok_params: TokenizeParams | None = None,
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
        """
        Async variant of `render_chat`; conversations are rendered together
        via `asyncio.gather`.

        Args:
            conversations: One message list per conversation.
            chat_params: Chat parameters used to render the messages.
            tok_params: Tokenization parameters; defaults to
                `default_chat_tok_params`.
            prompt_extras: Optional extra keys merged into every prompt
                after tokenization.

        Returns:
            A pair `(conversations, engine_inputs)` with one entry per
            input conversation.
        """
        arrival_time = time.time()

        params = self.default_chat_tok_params if tok_params is None else tok_params

        pending = [
            self.render_messages_async(conversation, chat_params)
            for conversation in conversations
        ]
        rendered = await asyncio.gather(*pending)

        out_conversations: list[list["ConversationMessage"]] = [
            conv for conv, _ in rendered
        ]
        dict_prompts: list[DictPrompt] = [dict_prompt for _, dict_prompt in rendered]

        tok_prompts = await self.tokenize_prompts_async(dict_prompts, params)
        self._apply_prompt_extras(tok_prompts, prompt_extras)

        to_engine = self.process_for_engine
        eng_prompts = [to_engine(tok_prompt, arrival_time) for tok_prompt in tok_prompts]

        return out_conversations, eng_prompts