base.py 32 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import copy
5
import time
6
from abc import ABC, abstractmethod
7
from collections.abc import Mapping, Sequence
8
from concurrent.futures import Executor, ThreadPoolExecutor
9
10
11
12
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, overload

from typing_extensions import TypeVar
13

14
from vllm.inputs import (
15
    EmbedsInput,
16
    EmbedsPrompt,
17
18
19
20
21
22
    EncoderDecoderInput,
    EngineInput,
    MultiModalDataDict,
    MultiModalInput,
    MultiModalUUIDDict,
    SingletonInput,
23
    TextPrompt,
24
    TokensInput,
25
    TokensPrompt,
26
27
28
    build_enc_dec_input,
    embeds_input,
    tokens_input,
29
)
30
from vllm.logger import init_logger
31
32
33
34
35
36
37
38
39
40
from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.parse import (
    MultiModalDataItems,
    MultiModalUUIDItems,
    parse_mm_uuids,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
from vllm.multimodal.registry import MultiModalTimingRegistry
41
from vllm.tokenizers import TokenizerLike
42
43
44
45
from vllm.utils.async_utils import (
    AsyncMicrobatchTokenizer,
    make_async,
)
46
from vllm.utils.counter import AtomicCounter
47
48
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.metrics.stats import MultiModalCacheStats
49
50

from .embed_utils import safe_load_prompt_embeds
51
52
53
54
from .inputs import (
    DictPrompt,
    EncoderDecoderDictPrompt,
    EncoderDecoderTokPrompt,
55
56
    SingletonDictPrompt,
    SingletonTokPrompt,
57
58
    TokPrompt,
)
59
from .inputs.preprocess import extract_target_prompt
60
from .params import ChatParams, TokenizeParams
61
62

if TYPE_CHECKING:
63
    from vllm.config import VllmConfig
64
65
66
67
68
    from vllm.entrypoints.chat_utils import (
        ChatCompletionMessageParam,
        ConversationMessage,
    )

69
70
# Module-level logger for renderer diagnostics.
logger = init_logger(__name__)

# Type parameter for the tokenizer held by ``BaseRenderer``; bounded by and
# defaulting to ``TokenizerLike``.
_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)


class BaseRenderer(ABC, Generic[_T]):
    def __init__(self, config: "VllmConfig", tokenizer: _T | None) -> None:
        """Set up the renderer's shared state.

        Args:
            config: Full vLLM configuration; the model, parallel, cache and
                observability sections are consulted.
            tokenizer: Tokenizer to render with, or ``None`` when tokenizer
                initialization was skipped.
        """
        super().__init__()

        model_config = config.model_config

        self.config = config
        self.model_config = model_config
        self.api_process_rank = config.parallel_config._api_process_rank

        self.tokenizer = tokenizer

        # A single thread pool serves both blocking tokenizer calls and
        # multimodal preprocessing.  Running them concurrently is safe
        # because the multimodal processor is handed a deep copy of the
        # tokenizer (see #36557).
        self._executor = ThreadPoolExecutor(
            max_workers=model_config.renderer_num_workers
        )

        # Multimodal preprocessing is always pushed onto the pool so the
        # asyncio event loop stays responsive under concurrent load.
        self._mm_executor: Executor = self._executor

        # Built on first use, because offline LLM usage never needs it.
        self._async_tokenizer: AsyncMicrobatchTokenizer | None = None

        self.mm_processor: BaseMultiModalProcessor | None = None
        self._mm_cache_stats: MultiModalCacheStats | None = None

        self._clear_mm_cache_async = make_async(
            self.clear_mm_cache, executor=self._executor
        )
        self._process_multimodal_async = make_async(
            self._process_multimodal, executor=self._mm_executor
        )

        if not model_config.is_multimodal_model:
            return

        processor_cache = mm_registry.processor_cache_from_config(config)

        # The multimodal processor gets its own copy of the tokenizer, and so
        # its own Rust tokenizer backend.  Sharing one backend between
        # AsyncMicrobatchTokenizer and call_hf_processor triggers
        # "RuntimeError: Already borrowed" from the Rust RefCell.
        # See: https://github.com/huggingface/tokenizers/issues/537
        tokenizer_copy = copy.deepcopy(tokenizer)

        with set_default_torch_num_threads():
            self.mm_processor = mm_registry.create_processor(
                model_config,
                tokenizer=tokenizer_copy,
                cache=processor_cache,
            )

        if processor_cache:
            self._mm_cache_stats = MultiModalCacheStats()

        # Source of internal request ids for multimodal processing; these
        # are unrelated to the engine-core request ids.
        self._mm_req_counter = AtomicCounter()

        self._mm_timing_registry = MultiModalTimingRegistry(
            config.observability_config
        )
133

134
    def get_tokenizer(self) -> _T:
135
136
137
138
139
140
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")

        return tokenizer

141
    def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
142
        if self._async_tokenizer is None:
143
144
145
            self._async_tokenizer = AsyncMicrobatchTokenizer(
                self.get_tokenizer(), executor=self._executor
            )
146
147
148

        return self._async_tokenizer

149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
    def get_mm_processor(self) -> "BaseMultiModalProcessor":
        if self.mm_processor is None:
            raise ValueError("Multi-modal processor not available for text-only models")

        return self.mm_processor

    @property
    def mm_processor_cache(self) -> "BaseMultiModalProcessorCache | None":
        if self.mm_processor is None:
            return None

        return self.mm_processor.cache

    def stat_mm_cache(self) -> MultiModalCacheStats | None:
        mm_cache_stats = self._mm_cache_stats
        if mm_cache_stats is None:
            return None

        self._mm_cache_stats = MultiModalCacheStats()

        return mm_cache_stats

    def update_mm_cache_stats(self) -> None:
        mm_processor_cache = self.mm_processor_cache
        mm_cache_stats = self._mm_cache_stats

        if mm_processor_cache and mm_cache_stats:
            delta = mm_processor_cache.make_stats(delta=True)
            mm_cache_stats.record(delta.total, delta.hits)

    def clear_mm_cache(self) -> None:
        mm_processor_cache = self.mm_processor_cache
        if mm_processor_cache is not None:
            mm_processor_cache.clear_cache()

        if self._mm_cache_stats is not None:
            self._mm_cache_stats.reset = True

187
188
189
190
191
192
193
    def warmup(self, chat_params: ChatParams) -> None:
        """
        Warm up this renderer to avoid first-request latency.

        For chat requests:
        - Jinja2 template compilation

        For multimodal models:
        - One pass of the multimodal processor over dummy inputs (one item
          per allowed modality), after which the processor cache is cleared.

        Exceptions raised inside either warmup pass are logged, not raised.
        """
        from vllm.entrypoints.chat_utils import ChatTemplateResolutionError

        try:
            logger.debug("Warming up chat template processing...")
            start_time = time.perf_counter()

            self.render_chat([[{"role": "user", "content": "warmup"}]], chat_params)

            elapsed = time.perf_counter() - start_time
            logger.debug("Chat template warmup completed in %.3fs", elapsed)
        except ChatTemplateResolutionError:
            logger.debug("This model does not support chat template.")
        except Exception:
            logger.warning("Chat template warmup failed", exc_info=True)

        if self.mm_processor:
            from vllm.multimodal.processing import TimingContext

            model_config = self.model_config
            mm_config = model_config.get_multimodal_config()
            processor = self.mm_processor
            mm_limits = processor.info.allowed_mm_limits

            try:
                logger.debug("Warming up multi-modal processing...")
                start_time = time.perf_counter()

                processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
                    seq_len=model_config.max_model_len,
                    mm_counts=dict.fromkeys(mm_limits, 1),
                    mm_options=mm_config.limit_per_prompt,
                )
                _ = processor.apply(
                    processor_inputs, timing_ctx=TimingContext(enabled=False)
                )

                elapsed = time.perf_counter() - start_time
                logger.info("Multi-modal warmup completed in %.3fs", elapsed)
            except Exception:
                # Keep the traceback (as the chat-template branch does);
                # otherwise the cause of a failed warmup is lost.
                logger.warning("Multi-modal warmup failed", exc_info=True)
            finally:
                # Drop anything the dummy pass put into the processor cache.
                self.clear_mm_cache()

237
238
239
240
241
    async def clear_mm_cache_async(self) -> None:
        """Serialize clear_mm_cache through the shared executor to avoid
        races with concurrent process_inputs on the mm_processor_cache."""
        await self._clear_mm_cache_async()

242
243
244
245
246
    def shutdown(self) -> None:
        mm_processor_cache = self.mm_processor_cache
        if mm_processor_cache is not None:
            mm_processor_cache.close()

247
248
249
250
251
252
253
254
        if executor := getattr(self, "_executor", None):
            executor.shutdown(wait=False)

        if (
            mm_executor := getattr(self, "_mm_executor", None)
        ) is not None and mm_executor is not executor:
            mm_executor.shutdown(wait=False)

255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
    def get_bos_token_id(self) -> int | None:
        if self.tokenizer is None:
            logger.warning_once(
                "Using None for BOS token id because tokenizer is not initialized"
            )
            return None

        return self.tokenizer.bos_token_id

    def get_eos_token_id(self) -> int | None:
        if self.tokenizer is None:
            logger.warning_once(
                "Using None for EOS token id because tokenizer is not initialized"
            )
            return None

        return self.tokenizer.eos_token_id

273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
    def get_dec_start_token_id(self) -> int:
        """
        Obtain the decoder start token id employed by an encoder/decoder model,
        raising an error if it is not available.
        """
        dec_start_token_id = getattr(
            self.model_config.hf_config, "decoder_start_token_id", None
        )

        if dec_start_token_id is None:
            logger.warning_once(
                "Falling back on <BOS> for decoder start token id "
                "because decoder start token id is not available."
            )
            dec_start_token_id = self.get_bos_token_id()

        if dec_start_token_id is None:
            raise RuntimeError("Cannot find decoder start token id or <BOS>")

        return dec_start_token_id

294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
    @cached_property
    def default_cmpl_tok_params(self) -> TokenizeParams:
        mm_processor = self.mm_processor
        if mm_processor is not None:
            return mm_processor.info.default_tok_params

        model_config = self.model_config
        encoder_config = model_config.encoder_config or {}

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=True,
        )

    @cached_property
    def default_chat_tok_params(self) -> TokenizeParams:
        mm_processor = self.mm_processor
        if mm_processor is not None:
            return mm_processor.info.default_tok_params

        model_config = self.model_config
        encoder_config = model_config.encoder_config or {}

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=False,
        )

324
    # Step 1: Convert raw inputs to prompts
325
    def render_prompt(
326
        self,
327
328
329
        prompt: DictPrompt | bytes,
    ) -> DictPrompt:
        if isinstance(prompt, bytes):
330
            embeds = safe_load_prompt_embeds(self.model_config, prompt)
331
            prompt = EmbedsPrompt(prompt_embeds=embeds)
332

333
        return prompt
334

335
    def render_prompts(
336
        self,
337
338
339
        prompts: Sequence[DictPrompt | bytes],
    ) -> list[DictPrompt]:
        if len(prompts) == 0:
340
341
            raise ValueError("You must pass at least one prompt")

342
        return [self.render_prompt(prompt) for prompt in prompts]
343

344
    async def render_prompts_async(
345
        self,
346
347
348
        prompts: Sequence[DictPrompt | bytes],
    ) -> list[DictPrompt]:
        return self.render_prompts(prompts)
349

350
    @abstractmethod
351
352
353
    def render_messages(
        self,
        messages: list["ChatCompletionMessageParam"],
354
        params: ChatParams,
355
    ) -> tuple[list["ConversationMessage"], DictPrompt]:
356
357
358
359
360
        raise NotImplementedError

    async def render_messages_async(
        self,
        messages: list["ChatCompletionMessageParam"],
361
        params: ChatParams,
362
    ) -> tuple[list["ConversationMessage"], DictPrompt]:
363
364
365
        return self.render_messages(messages, params)

    # Step 2: Tokenize prompts if necessary
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
    def _tokenize_prompt(
        self,
        prompt: TextPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt:
        """Encode ``prompt["prompt"]`` and return a tokens prompt that keeps
        every original key alongside the new ``prompt_token_ids``."""
        tokenizer = self.get_tokenizer()
        text = prompt["prompt"]
        token_ids = tokenizer.encode(text, **params.get_encode_kwargs())

        return TokensPrompt(prompt_token_ids=token_ids, **prompt)

    async def _tokenize_prompt_async(
        self,
        prompt: TextPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt:
        """Async variant of :meth:`_tokenize_prompt` using the microbatching
        async tokenizer."""
        async_tokenizer = self.get_async_tokenizer()
        text = prompt["prompt"]
        token_ids = await async_tokenizer.encode(text, **params.get_encode_kwargs())

        return TokensPrompt(prompt_token_ids=token_ids, **prompt)

    def _detokenize_prompt(self, prompt: TokensPrompt) -> TokensPrompt:
        tokenizer = self.get_tokenizer()
        prompt["prompt"] = tokenizer.decode(prompt["prompt_token_ids"])

        return prompt

    async def _detokenize_prompt_async(self, prompt: TokensPrompt) -> TokensPrompt:
        tokenizer = self.get_async_tokenizer()
        prompt["prompt"] = await tokenizer.decode(prompt["prompt_token_ids"])

        return prompt

404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
    @overload
    def _tokenize_singleton_prompt(
        self,
        prompt: TextPrompt | TokensPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt: ...

    @overload
    def _tokenize_singleton_prompt(  # type: ignore[misc]
        self,
        prompt: EmbedsPrompt,
        params: TokenizeParams,
    ) -> EmbedsPrompt: ...

    def _tokenize_singleton_prompt(
        self,
        prompt: SingletonDictPrompt,
        params: TokenizeParams,
    ) -> SingletonTokPrompt:
        if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
            prompt = params.apply_pre_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]
            prompt = self._tokenize_prompt(prompt, params)

        if params.needs_detokenization and "prompt" not in prompt:
            if "prompt_token_ids" not in prompt:
                raise RuntimeError("Cannot run detokenization on embeddings")

            prompt = self._detokenize_prompt(prompt)  # type: ignore[arg-type]

        return params.apply_post_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]

    @overload
    async def _tokenize_singleton_prompt_async(
        self,
        prompt: TextPrompt | TokensPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt: ...

    @overload
    async def _tokenize_singleton_prompt_async(  # type: ignore[misc]
        self,
        prompt: EmbedsPrompt,
        params: TokenizeParams,
    ) -> EmbedsPrompt: ...

    async def _tokenize_singleton_prompt_async(
        self,
        prompt: SingletonDictPrompt,
        params: TokenizeParams,
    ) -> SingletonTokPrompt:
        if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
            prompt = params.apply_pre_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]
            prompt = await self._tokenize_prompt_async(prompt, params)

        if params.needs_detokenization and "prompt" not in prompt:
            if "prompt_token_ids" not in prompt:
                raise RuntimeError("Cannot run detokenization on embeddings")

            prompt = await self._detokenize_prompt_async(prompt)  # type: ignore[arg-type]

        return params.apply_post_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]

466
467
468
469
470
471
    def _tokenize_enc_dec_prompt(
        self,
        prompt: EncoderDecoderDictPrompt,
        params: TokenizeParams,
    ) -> EncoderDecoderTokPrompt:
        """Tokenize the encoder prompt and, when present, the decoder prompt."""
        encoder_prompt = self._tokenize_singleton_prompt(
            prompt["encoder_prompt"], params
        )

        raw_decoder_prompt = prompt["decoder_prompt"]
        decoder_prompt = (
            None
            if raw_decoder_prompt is None
            else self._tokenize_singleton_prompt(raw_decoder_prompt, params)
        )

        return EncoderDecoderTokPrompt(
            encoder_prompt=encoder_prompt,
            decoder_prompt=decoder_prompt,
        )

    async def _tokenize_enc_dec_prompt_async(
        self,
        prompt: EncoderDecoderDictPrompt,
        params: TokenizeParams,
    ) -> EncoderDecoderTokPrompt:
        """Async variant of :meth:`_tokenize_enc_dec_prompt`.

        The encoder and decoder prompts are tokenized concurrently. A missing
        decoder prompt is represented by ``asyncio.sleep(0)``, which resolves
        to ``None``.
        """
        encoder_coro = self._tokenize_singleton_prompt_async(
            prompt["encoder_prompt"], params
        )

        raw_decoder_prompt = prompt["decoder_prompt"]
        if raw_decoder_prompt is None:
            decoder_coro = asyncio.sleep(0)
        else:
            decoder_coro = self._tokenize_singleton_prompt_async(
                raw_decoder_prompt, params
            )

        encoder_prompt, decoder_prompt = await asyncio.gather(
            encoder_coro, decoder_coro
        )

        return EncoderDecoderTokPrompt(
            encoder_prompt=encoder_prompt,
            decoder_prompt=decoder_prompt,
        )

    def tokenize_prompt(
        self,
        prompt: DictPrompt,
        params: TokenizeParams,
    ) -> TokPrompt:
        if "encoder_prompt" in prompt:
            return self._tokenize_enc_dec_prompt(prompt, params)  # type: ignore[arg-type]

514
        return self._tokenize_singleton_prompt(prompt, params)
515
516
517

    def tokenize_prompts(
        self,
518
        prompts: Sequence[DictPrompt],
519
        params: TokenizeParams,
520
    ) -> list[TokPrompt]:
521
522
        return [self.tokenize_prompt(prompt, params) for prompt in prompts]

523
524
525
526
527
528
529
    async def tokenize_prompt_async(
        self,
        prompt: DictPrompt,
        params: TokenizeParams,
    ) -> TokPrompt:
        if "encoder_prompt" in prompt:
            return await self._tokenize_enc_dec_prompt_async(prompt, params)  # type: ignore[arg-type]
530

531
        return await self._tokenize_singleton_prompt_async(prompt, params)
532
533
534

    async def tokenize_prompts_async(
        self,
535
        prompts: Sequence[DictPrompt],
536
        params: TokenizeParams,
537
    ) -> list[TokPrompt]:
538
539
540
        return await asyncio.gather(
            *(self.tokenize_prompt_async(prompt, params) for prompt in prompts)
        )
541
542
543
544

    # Step 3: Add extra keys to the prompts
    def _apply_prompt_extras(
        self,
545
        prompts: Sequence[TokPrompt],
546
547
548
549
550
551
        prompt_extras: dict[str, Any] | None,
    ):
        if not prompt_extras:
            return

        for prompt in prompts:
552
            target_prompt = extract_target_prompt(self.model_config, prompt)
553
554
            target_prompt.update(prompt_extras)  # type: ignore[arg-type]

555
556
557
    # Step 4: Convert to engine inputs
    def _validate_mm_uuids(
        self,
558
559
560
        mm_data: MultiModalDataDict,
        mm_data_items: MultiModalDataItems,
        mm_uuid_items: MultiModalUUIDItems,
561
    ) -> None:
562
563
564
        # NOTE: Keys corresponding to `None` in `mm_data` don't appear in
        # `mm_data_items`
        modalities = mm_data.keys() | mm_uuid_items.keys()
565
566

        for modality in modalities:
567
568
            data_items = mm_data_items.get(modality)
            uuid_items = mm_uuid_items.get(modality)
569

570
571
572
573
574
575
            if data_items is None:
                if uuid_items is None:
                    raise ValueError(
                        f"multi_modal_data[{modality!r}] is empty but "
                        f"multi_modal_uuids[{modality!r}] is missing."
                    )
576

577
            elif uuid_items is not None:
578
                if len(data_items) != len(uuid_items):
579
580
581
582
583
584
585
                    raise ValueError(
                        f"If given, multi_modal_uuids[{modality!r}] must have "
                        f"same length as multi_modal_data[{modality!r}], but "
                        f"got {len(uuid_items)} vs {len(data_items)}."
                    )

                for i, item in enumerate(data_items):
586
587
588
589
590
                    if item is None and uuid_items[i] is None:
                        raise ValueError(
                            f"multi_modal_data[{modality!r}][{i}] is empty but "
                            f"multi_modal_uuids[{modality!r}][{i}] is missing."
                        )
591
592
593

    def _process_mm_uuids(
        self,
594
595
596
        mm_data: MultiModalDataDict,
        mm_data_items: MultiModalDataItems,
        mm_uuid_items: MultiModalUUIDItems,
597
        mm_req_id: str,
598
    ) -> MultiModalUUIDItems:
599
600
601
602
603
604
605
606
607
608
609
610
        model_config = self.model_config

        # NOTE: When users explicitly turn off BOTH prefix caching and input
        # processing caching, no multimodal features or embeddings will be
        # reused across requests, therefore identifying multimodal data items
        # by their content is no longer necessary, and we create uuids with
        # `<mm_req_id>-<modality>-<index>`, overriding even user-provided ones.
        if (
            model_config.multimodal_config
            and model_config.multimodal_config.mm_processor_cache_gb == 0
            and not self.config.cache_config.enable_prefix_caching
        ):
611
            mm_uuid_items = {
612
                modality: [f"{mm_req_id}-{modality}-{i}" for i in range(data_count)]
613
                for modality, data_count in mm_data_items.get_all_counts().items()
614
615
            }

616
        self._validate_mm_uuids(mm_data, mm_data_items, mm_uuid_items)
617

618
        return mm_uuid_items
619
620
621
622
623

    # TODO: Remove str and tokenization_kwargs after deprecating InputPreprocessor
    def _process_multimodal(
        self,
        prompt: list[int] | str,
624
625
        mm_data: MultiModalDataDict,
        mm_uuids: MultiModalUUIDDict | None,
626
627
        mm_processor_kwargs: Mapping[str, object] | None,
        tokenization_kwargs: dict[str, Any] | None,
628
    ) -> "MultiModalInput":
629
        mm_req_id = f"renderer{self.api_process_rank}-mm-{self._mm_req_counter.inc(1)}"
630
631
632

        mm_processor = self.get_mm_processor()

633
634
635
        mm_data_items = mm_processor.info.parse_mm_data(mm_data)
        mm_uuid_items = parse_mm_uuids(mm_uuids)

636
        mm_uuid_items = self._process_mm_uuids(
637
638
            mm_data, mm_data_items, mm_uuid_items, mm_req_id
        )
639

640
641
642
643
644
645
646
647
648
649
650
        mm_processor_inputs = MMProcessorInputs(
            prompt,
            mm_data_items,
            mm_uuid_items,
            hf_processor_mm_kwargs=mm_processor_kwargs or {},
            tokenization_kwargs=tokenization_kwargs or {},
        )
        mm_timing_ctx = self._mm_timing_registry.get(mm_req_id)

        with set_default_torch_num_threads():
            mm_inputs = mm_processor.apply(mm_processor_inputs, mm_timing_ctx)
651
652
653
654
655
656
657
658

        self.update_mm_cache_stats()

        return mm_inputs

    def _process_tokens(
        self,
        prompt: TokensPrompt,
659
    ) -> TokensInput | MultiModalInput:
660
661
662
        """Process token inputs, with multimodal preprocessing offloaded
        to the shared thread pool in the async variant.
        """
663
664
        prompt_token_ids = prompt["prompt_token_ids"]

665
        engine_input: TokensInput | MultiModalInput
666
        if multi_modal_data := prompt.get("multi_modal_data"):
667
            engine_input = self._process_multimodal(
668
669
670
671
672
673
674
                prompt_token_ids,
                multi_modal_data,
                mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
                tokenization_kwargs=None,  # Tokenization already done in Step 2
                mm_uuids=prompt.get("multi_modal_uuids"),
            )
        else:
675
            engine_input = tokens_input(prompt_token_ids)
676
677

        if prompt_text := prompt.get("prompt"):
678
            engine_input["prompt"] = prompt_text
679
        if cache_salt := prompt.get("cache_salt"):
680
            engine_input["cache_salt"] = cache_salt
681

682
        return engine_input
683

684
    def _process_embeds(self, prompt: EmbedsPrompt) -> EmbedsInput:
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
        if not self.model_config.enable_prompt_embeds:
            raise ValueError(
                "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
            )

        prompt_embeds = prompt["prompt_embeds"]

        # prompt_embeds must be (seq_len, hidden_size), but if the user
        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # we can unambiguously process the intent by squeezing the batch
        # dimension.
        if prompt_embeds.ndim == 3:
            prompt_embeds = prompt_embeds.squeeze(dim=0)

        if prompt_embeds.ndim != 2:
            raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")

        # Tensors must be on CPU for serialization between processes
        # in the MsgpackEncoder. Casting to CPU here ensures that there is no
        # hidden device transfer in the critical path of generation.
        prompt_embeds = prompt_embeds.cpu()

707
        return embeds_input(
708
709
710
711
            prompt_embeds=prompt_embeds,
            cache_salt=prompt.get("cache_salt"),
        )

712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
    async def _process_tokens_async(
        self,
        prompt: TokensPrompt,
    ) -> TokensInput | MultiModalInput:
        prompt_token_ids = prompt["prompt_token_ids"]

        engine_input: TokensInput | MultiModalInput
        if multi_modal_data := prompt.get("multi_modal_data"):
            engine_input = await self._process_multimodal_async(
                prompt_token_ids,
                multi_modal_data,
                mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
                tokenization_kwargs=None,
                mm_uuids=prompt.get("multi_modal_uuids"),
            )
        else:
            engine_input = tokens_input(prompt_token_ids)

        if prompt_text := prompt.get("prompt"):
            engine_input["prompt"] = prompt_text
        if cache_salt := prompt.get("cache_salt"):
            engine_input["cache_salt"] = cache_salt

        return engine_input

737
    def _process_singleton(self, prompt: SingletonTokPrompt) -> SingletonInput:
738
739
740
741
742
        if "prompt_embeds" in prompt:
            return self._process_embeds(prompt)  # type: ignore[arg-type]

        return self._process_tokens(prompt)  # type: ignore[arg-type]

743
744
745
746
747
748
749
750
751
    async def _process_singleton_async(
        self,
        prompt: SingletonTokPrompt,
    ) -> SingletonInput:
        if "prompt_embeds" in prompt:
            return self._process_embeds(prompt)  # type: ignore[arg-type]

        return await self._process_tokens_async(prompt)  # type: ignore[arg-type]

752
753
754
    def _process_enc_dec(
        self,
        prompt: EncoderDecoderTokPrompt,
    ) -> EncoderDecoderInput:
        """Build the engine input for an encoder/decoder prompt.

        Both halves are processed as singletons and combined with
        :func:`build_enc_dec_input`. ``skip_decoder_start_token`` is taken
        from the multimodal processor when it is an
        ``EncDecMultiModalProcessor``; otherwise it is ``False``.
        """
        enc_prompt = prompt["encoder_prompt"]
        dec_prompt = prompt["decoder_prompt"]

        skip_decoder_start_token = False
        mm_processor = self.mm_processor
        if mm_processor is not None:
            from vllm.multimodal.processing import EncDecMultiModalProcessor

            skip_decoder_start_token = (
                isinstance(mm_processor, EncDecMultiModalProcessor)
                and mm_processor.skip_decoder_start_token
            )

        encoder_input = self._process_singleton(enc_prompt)
        decoder_input = (
            None if dec_prompt is None else self._process_singleton(dec_prompt)
        )

        return build_enc_dec_input(
            encoder_input=encoder_input,
            decoder_input=decoder_input,
            decoder_start_token_id=self.get_dec_start_token_id(),
            skip_decoder_start_token=skip_decoder_start_token,
        )

775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
    async def _process_enc_dec_async(
        self,
        prompt: EncoderDecoderTokPrompt,
    ) -> EncoderDecoderInput:
        """Async variant of :meth:`_process_enc_dec`.

        The encoder and decoder prompts are processed concurrently. A missing
        decoder prompt becomes ``asyncio.sleep(0)``, which resolves to
        ``None``. The results are combined with :func:`build_enc_dec_input`.
        """
        enc_prompt = prompt["encoder_prompt"]
        dec_prompt = prompt["decoder_prompt"]

        # Mirror the sync path: an EncDecMultiModalProcessor may opt out of
        # inserting the decoder start token. The flag must be forwarded here
        # too, so async requests are built the same way as sync ones.
        skip_decoder_start_token = False
        if self.mm_processor is not None:
            from vllm.multimodal.processing import EncDecMultiModalProcessor

            if isinstance(self.mm_processor, EncDecMultiModalProcessor):
                skip_decoder_start_token = self.mm_processor.skip_decoder_start_token

        encoder_input, decoder_input = await asyncio.gather(
            self._process_singleton_async(enc_prompt),
            (
                asyncio.sleep(0)
                if dec_prompt is None
                else self._process_singleton_async(dec_prompt)
            ),
        )

        return build_enc_dec_input(
            encoder_input=encoder_input,
            decoder_input=decoder_input,
            decoder_start_token_id=self.get_dec_start_token_id(),
            skip_decoder_start_token=skip_decoder_start_token,
        )

797
798
    def process_for_engine(self, prompt: TokPrompt, arrival_time: float) -> EngineInput:
        """Turn a tokenized prompt into an engine input tagged with its arrival time.

        Prompts that carry an ``"encoder_prompt"`` key go through
        ``_process_enc_dec``; all others go through ``_process_singleton``.
        The resulting input is mutated in place to record ``arrival_time``
        and then returned.
        """
        is_enc_dec = "encoder_prompt" in prompt

        processed: EngineInput = (
            self._process_enc_dec(prompt)  # type: ignore[arg-type]
            if is_enc_dec
            else self._process_singleton(prompt)
        )
        processed["arrival_time"] = arrival_time
        return processed
807

808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
    async def process_for_engine_async(
        self, prompt: TokPrompt, arrival_time: float
    ) -> EngineInput:
        """Async variant of ``process_for_engine``.

        Awaits ``_process_singleton_async`` for ordinary prompts and
        ``_process_enc_dec_async`` for prompts with an ``"encoder_prompt"``
        key, then stores ``arrival_time`` on the result before returning it.
        """
        processed: EngineInput
        if "encoder_prompt" not in prompt:
            processed = await self._process_singleton_async(prompt)
        else:
            processed = await self._process_enc_dec_async(
                prompt  # type: ignore[arg-type]
            )

        processed["arrival_time"] = arrival_time
        return processed

823
824
825
826
    # Top-level methods
    def render_cmpl(
        self,
        prompts: Sequence[DictPrompt | bytes],
        tok_params: TokenizeParams | None = None,
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
        """Render completion prompts into engine inputs.

        Args:
            prompts: Completion prompts (dict prompts or raw bytes), rendered
                via ``render_prompts``.
            tok_params: Tokenization parameters; ``self.default_cmpl_tok_params``
                is used when this is ``None``.
            prompt_extras: Optional extras passed, together with the tokenized
                prompts, to ``_apply_prompt_extras``.

        Returns:
            A list with one engine input per tokenized prompt. All inputs share
            the arrival time captured at the start of this call.
        """
        now = time.time()
        params = self.default_cmpl_tok_params if tok_params is None else tok_params

        tokenized = self.tokenize_prompts(self.render_prompts(prompts), params)
        self._apply_prompt_extras(tokenized, prompt_extras)

        return [self.process_for_engine(tok_prompt, now) for tok_prompt in tokenized]
842
843
844
845

    async def render_cmpl_async(
        self,
        prompts: Sequence[DictPrompt | bytes],
        tok_params: TokenizeParams | None = None,
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
        """Async variant of ``render_cmpl``.

        Renders and tokenizes ``prompts`` with the async helpers, applies
        ``prompt_extras``, then processes every tokenized prompt concurrently
        with ``process_for_engine_async``.

        Returns:
            A list of engine inputs in the same order as the tokenized prompts.
            All inputs share the arrival time captured at the start of the call.
        """
        now = time.time()
        params = self.default_cmpl_tok_params if tok_params is None else tok_params

        rendered = await self.render_prompts_async(prompts)
        tokenized = await self.tokenize_prompts_async(rendered, params)
        self._apply_prompt_extras(tokenized, prompt_extras)

        pending = [
            self.process_for_engine_async(tok_prompt, now) for tok_prompt in tokenized
        ]
        return await asyncio.gather(*pending)
863
864
865
866
867

    def render_chat(
        self,
        conversations: Sequence[list["ChatCompletionMessageParam"]],
        chat_params: ChatParams,
        tok_params: TokenizeParams | None = None,
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
        """Render chat conversations and convert them into engine inputs.

        Each conversation is first rendered with ``render_messages``, which
        yields a (conversation, dict prompt) pair. The dict prompts are then
        tokenized, augmented via ``_apply_prompt_extras`` and processed with
        ``process_for_engine``.

        Args:
            conversations: Chat message lists, one per request.
            chat_params: Parameters forwarded to ``render_messages``.
            tok_params: Tokenization parameters; ``self.default_chat_tok_params``
                is used when this is ``None``.
            prompt_extras: Optional extras passed to ``_apply_prompt_extras``.

        Returns:
            A tuple ``(conversations, engine_inputs)``, where both lists are
            ordered like the input conversations.
        """
        now = time.time()
        params = self.default_chat_tok_params if tok_params is None else tok_params

        # Render every conversation before unpacking any of the results.
        rendered = [self.render_messages(conv, chat_params) for conv in conversations]

        out_conversations: list[list["ConversationMessage"]] = []
        dict_prompts: list[DictPrompt] = []
        for rendered_conv, rendered_prompt in rendered:
            out_conversations.append(rendered_conv)
            dict_prompts.append(rendered_prompt)

        tokenized = self.tokenize_prompts(dict_prompts, params)
        self._apply_prompt_extras(tokenized, prompt_extras)

        engine_inputs = [
            self.process_for_engine(tok_prompt, now) for tok_prompt in tokenized
        ]
        return out_conversations, engine_inputs
897
898
899
900
901

    async def render_chat_async(
        self,
        conversations: Sequence[list["ChatCompletionMessageParam"]],
        chat_params: ChatParams,
        tok_params: TokenizeParams | None = None,
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
        """Async variant of ``render_chat``.

        All conversations are rendered concurrently with
        ``render_messages_async``. The resulting dict prompts are tokenized
        with ``tokenize_prompts_async`` and then processed concurrently with
        ``process_for_engine_async``.

        Returns:
            A tuple ``(conversations, engine_inputs)``, where both lists are
            ordered like the input conversations.
        """
        now = time.time()
        params = self.default_chat_tok_params if tok_params is None else tok_params

        rendered = await asyncio.gather(
            *(self.render_messages_async(conv, chat_params) for conv in conversations)
        )

        out_conversations: list[list["ConversationMessage"]] = []
        dict_prompts: list[DictPrompt] = []
        for rendered_conv, rendered_prompt in rendered:
            out_conversations.append(rendered_conv)
            dict_prompts.append(rendered_prompt)

        tokenized = await self.tokenize_prompts_async(dict_prompts, params)
        self._apply_prompt_extras(tokenized, prompt_extras)

        engine_inputs = await asyncio.gather(
            *(self.process_for_engine_async(tok_prompt, now) for tok_prompt in tokenized)
        )
        return out_conversations, engine_inputs