base.py 25.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import time
5
from abc import ABC, abstractmethod
6
from collections.abc import Mapping, Sequence
7
8
9
10
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, overload

from typing_extensions import TypeVar
11

12
13
14
15
16
17
18
19
20
21
22
from vllm.inputs import (
    EmbedsInputs,
    EmbedsPrompt,
    EncoderDecoderInputs,
    ProcessorInputs,
    SingletonInputs,
    TextPrompt,
    TokenInputs,
    TokensPrompt,
)
from vllm.inputs.data import build_enc_dec_inputs, embeds_inputs, token_inputs
23
from vllm.logger import init_logger
24
from vllm.tokenizers import TokenizerLike
25
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
26
from vllm.utils.counter import AtomicCounter
27
28
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.metrics.stats import MultiModalCacheStats
29
30

from .embed_utils import safe_load_prompt_embeds
31
32
33
34
from .inputs import (
    DictPrompt,
    EncoderDecoderDictPrompt,
    EncoderDecoderTokPrompt,
35
36
    SingletonDictPrompt,
    SingletonTokPrompt,
37
38
    TokPrompt,
)
39
from .inputs.preprocess import extract_target_prompt
40
from .params import ChatParams, TokenizeParams
41
42

if TYPE_CHECKING:
43
    from vllm.config import VllmConfig
44
45
46
47
    from vllm.entrypoints.chat_utils import (
        ChatCompletionMessageParam,
        ConversationMessage,
    )
48
    from vllm.multimodal.cache import BaseMultiModalProcessorCache
49
50
51
52
53
54
    from vllm.multimodal.inputs import (
        MultiModalDataDict,
        MultiModalInputs,
        MultiModalUUIDDict,
    )
    from vllm.multimodal.parse import MultiModalDataItems
55
    from vllm.multimodal.processing import BaseMultiModalProcessor
56

57
58
logger = init_logger(__name__)

# Tokenizer type parameter for `BaseRenderer`; an unparameterized renderer
# falls back to the generic `TokenizerLike`.
_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)


class BaseRenderer(ABC, Generic[_T]):
64
    @classmethod
65
    @abstractmethod
66
67
    def from_config(
        cls,
68
        config: "VllmConfig",
69
        tokenizer_kwargs: dict[str, Any],
70
    ) -> "BaseRenderer":
71
72
        raise NotImplementedError

73
    def __init__(self, config: "VllmConfig", tokenizer: _T | None) -> None:
        """
        Initialize the renderer.

        Args:
            config: The full vLLM configuration; ``config.model_config`` is
                also kept as ``self.model_config``.
            tokenizer: The tokenizer to render with, or ``None`` when
                tokenizer initialization was skipped.
        """
        super().__init__()

        self.config = config
        self.model_config = config.model_config

        self.tokenizer = tokenizer

        # Lazy initialization since offline LLM doesn't use async
        self._async_tokenizer: AsyncMicrobatchTokenizer | None = None

        # This is used to generate internal request ID for MM processing.
        # It has no relation to the request ID for engine core.
        # NOTE: Created unconditionally (not only for multi-modal models) so
        # that `_process_multimodal` can never fail with an AttributeError;
        # for text-only models it surfaces `get_mm_processor`'s ValueError.
        self._mm_req_counter = AtomicCounter()

        self.mm_processor: BaseMultiModalProcessor | None = None
        self._mm_cache_stats: MultiModalCacheStats | None = None
        if config.model_config.is_multimodal_model:
            from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry

            mm_processor_cache = mm_registry.processor_cache_from_config(config)

            with set_default_torch_num_threads():
                self.mm_processor = mm_registry.create_processor(
                    config.model_config,
                    config.observability_config,
                    tokenizer=tokenizer,
                    cache=mm_processor_cache,
                )

            # Cache statistics are only tracked when a processor cache exists.
            if mm_processor_cache:
                self._mm_cache_stats = MultiModalCacheStats()

    def get_tokenizer(self) -> _T:
107
108
109
110
111
112
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")

        return tokenizer

113
    def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
114
        if self._async_tokenizer is None:
115
116
117
118
            self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer())

        return self._async_tokenizer

119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
    def get_mm_processor(self) -> "BaseMultiModalProcessor":
        """
        Return the multi-modal processor.

        Raises:
            ValueError: If no processor was created (text-only model).
        """
        processor = self.mm_processor
        if processor is not None:
            return processor

        raise ValueError("Multi-modal processor not available for text-only models")

    @property
    def mm_processor_cache(self) -> "BaseMultiModalProcessorCache | None":
        if self.mm_processor is None:
            return None

        return self.mm_processor.cache

    def stat_mm_cache(self) -> MultiModalCacheStats | None:
        mm_cache_stats = self._mm_cache_stats
        if mm_cache_stats is None:
            return None

        self._mm_cache_stats = MultiModalCacheStats()

        return mm_cache_stats

    def update_mm_cache_stats(self) -> None:
        """
        Record the processor cache's delta (total, hits) into the current stats.

        Does nothing unless both a processor cache and a stats accumulator
        exist.
        """
        mm_processor_cache = self.mm_processor_cache
        mm_cache_stats = self._mm_cache_stats

        # Explicit `None` checks (matching `clear_mm_cache`/`shutdown`) so a
        # cache or stats object that happens to be falsy (e.g. defines
        # `__len__` and is currently empty) is not silently skipped.
        if mm_processor_cache is None or mm_cache_stats is None:
            return

        delta = mm_processor_cache.make_stats(delta=True)
        mm_cache_stats.record(delta.total, delta.hits)

    def clear_mm_cache(self) -> None:
        """Clear the processor cache (if any) and flag the cache stats as reset."""
        cache = self.mm_processor_cache
        if cache is not None:
            cache.clear_cache()

        stats = self._mm_cache_stats
        if stats is not None:
            stats.reset = True

    def shutdown(self) -> None:
        """Close the multi-modal processor cache, if there is one."""
        mm_processor_cache = self.mm_processor_cache
        if mm_processor_cache is not None:
            mm_processor_cache.close()

    def get_bos_token_id(self) -> int | None:
        if self.tokenizer is None:
            logger.warning_once(
                "Using None for BOS token id because tokenizer is not initialized"
            )
            return None

        return self.tokenizer.bos_token_id

    def get_eos_token_id(self) -> int | None:
        if self.tokenizer is None:
            logger.warning_once(
                "Using None for EOS token id because tokenizer is not initialized"
            )
            return None

        return self.tokenizer.eos_token_id

180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
    def get_dec_start_token_id(self) -> int:
        """
        Obtain the decoder start token id employed by an encoder/decoder model.

        Uses ``hf_config.decoder_start_token_id`` when present; otherwise
        falls back (with a one-time warning) to the BOS token id.

        Raises:
            RuntimeError: If neither a decoder start token id nor a BOS token
                id is available.
        """
        dec_start_token_id = getattr(
            self.model_config.hf_config, "decoder_start_token_id", None
        )

        if dec_start_token_id is None:
            logger.warning_once(
                "Falling back on <BOS> for decoder start token id "
                "because decoder start token id is not available."
            )
            dec_start_token_id = self.get_bos_token_id()

        if dec_start_token_id is None:
            raise RuntimeError("Cannot find decoder start token id or <BOS>")

        return dec_start_token_id

    @cached_property
    def default_cmpl_tok_params(self) -> TokenizeParams:
        mm_processor = self.mm_processor
        if mm_processor is not None:
            return mm_processor.info.default_tok_params

        model_config = self.model_config
        encoder_config = model_config.encoder_config or {}

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=True,
        )

    @cached_property
    def default_chat_tok_params(self) -> TokenizeParams:
        mm_processor = self.mm_processor
        if mm_processor is not None:
            return mm_processor.info.default_tok_params

        model_config = self.model_config
        encoder_config = model_config.encoder_config or {}

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=False,
        )

231
    # Step 1: Convert raw inputs to prompts
232
    def render_prompt(
233
        self,
234
235
236
        prompt: DictPrompt | bytes,
    ) -> DictPrompt:
        if isinstance(prompt, bytes):
237
            embeds = safe_load_prompt_embeds(self.model_config, prompt)
238
            prompt = EmbedsPrompt(prompt_embeds=embeds)
239

240
        return prompt
241

242
    def render_prompts(
243
        self,
244
245
246
        prompts: Sequence[DictPrompt | bytes],
    ) -> list[DictPrompt]:
        if len(prompts) == 0:
247
248
            raise ValueError("You must pass at least one prompt")

249
        return [self.render_prompt(prompt) for prompt in prompts]
250

251
    async def render_prompts_async(
252
        self,
253
254
255
        prompts: Sequence[DictPrompt | bytes],
    ) -> list[DictPrompt]:
        return self.render_prompts(prompts)
256

257
    @abstractmethod
258
259
260
    def render_messages(
        self,
        messages: list["ChatCompletionMessageParam"],
261
        params: ChatParams,
262
    ) -> tuple[list["ConversationMessage"], DictPrompt]:
263
264
265
266
267
        raise NotImplementedError

    async def render_messages_async(
        self,
        messages: list["ChatCompletionMessageParam"],
268
        params: ChatParams,
269
    ) -> tuple[list["ConversationMessage"], DictPrompt]:
270
271
272
        return self.render_messages(messages, params)

    # Step 2: Tokenize prompts if necessary
    def _tokenize_prompt(
        self,
        prompt: TextPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt:
        """
        Encode a text prompt with the synchronous tokenizer.

        The returned ``TokensPrompt`` has the new ``prompt_token_ids`` plus
        every key of the original text prompt.
        """
        tokenizer = self.get_tokenizer()
        prompt_token_ids = tokenizer.encode(
            prompt["prompt"],
            **params.get_encode_kwargs(),
        )

        return TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)

    async def _tokenize_prompt_async(
        self,
        prompt: TextPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt:
        """
        Encode a text prompt with the async micro-batching tokenizer.

        Mirrors ``_tokenize_prompt``: the result keeps every key of the text
        prompt and adds ``prompt_token_ids``.
        """
        async_tokenizer = self.get_async_tokenizer()
        text = prompt["prompt"]
        token_ids = await async_tokenizer.encode(text, **params.get_encode_kwargs())

        return TokensPrompt(prompt_token_ids=token_ids, **prompt)

    def _detokenize_prompt(self, prompt: TokensPrompt) -> TokensPrompt:
        tokenizer = self.get_tokenizer()
        prompt["prompt"] = tokenizer.decode(prompt["prompt_token_ids"])

        return prompt

    async def _detokenize_prompt_async(self, prompt: TokensPrompt) -> TokensPrompt:
        tokenizer = self.get_async_tokenizer()
        prompt["prompt"] = await tokenizer.decode(prompt["prompt_token_ids"])

        return prompt

311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
    @overload
    def _tokenize_singleton_prompt(
        self,
        prompt: TextPrompt | TokensPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt: ...

    @overload
    def _tokenize_singleton_prompt(  # type: ignore[misc]
        self,
        prompt: EmbedsPrompt,
        params: TokenizeParams,
    ) -> EmbedsPrompt: ...

    def _tokenize_singleton_prompt(
        self,
        prompt: SingletonDictPrompt,
        params: TokenizeParams,
    ) -> SingletonTokPrompt:
        if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
            prompt = params.apply_pre_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]
            prompt = self._tokenize_prompt(prompt, params)

        if params.needs_detokenization and "prompt" not in prompt:
            if "prompt_token_ids" not in prompt:
                raise RuntimeError("Cannot run detokenization on embeddings")

            prompt = self._detokenize_prompt(prompt)  # type: ignore[arg-type]

        return params.apply_post_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]

    @overload
    async def _tokenize_singleton_prompt_async(
        self,
        prompt: TextPrompt | TokensPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt: ...

    @overload
    async def _tokenize_singleton_prompt_async(  # type: ignore[misc]
        self,
        prompt: EmbedsPrompt,
        params: TokenizeParams,
    ) -> EmbedsPrompt: ...

    async def _tokenize_singleton_prompt_async(
        self,
        prompt: SingletonDictPrompt,
        params: TokenizeParams,
    ) -> SingletonTokPrompt:
        if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
            prompt = params.apply_pre_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]
            prompt = await self._tokenize_prompt_async(prompt, params)

        if params.needs_detokenization and "prompt" not in prompt:
            if "prompt_token_ids" not in prompt:
                raise RuntimeError("Cannot run detokenization on embeddings")

            prompt = await self._detokenize_prompt_async(prompt)  # type: ignore[arg-type]

        return params.apply_post_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]

373
374
375
376
377
378
    def _tokenize_enc_dec_prompt(
        self,
        prompt: EncoderDecoderDictPrompt,
        params: TokenizeParams,
    ) -> EncoderDecoderTokPrompt:
        """
        Tokenize both halves of an encoder-decoder prompt.

        The encoder prompt is always tokenized. The decoder prompt is
        tokenized only when present; otherwise it stays ``None``.
        """
        enc_prompt = self._tokenize_singleton_prompt(prompt["encoder_prompt"], params)

        dec_dict_prompt = prompt["decoder_prompt"]
        dec_prompt = (
            None
            if dec_dict_prompt is None
            else self._tokenize_singleton_prompt(dec_dict_prompt, params)
        )

        return EncoderDecoderTokPrompt(
            encoder_prompt=enc_prompt,
            decoder_prompt=dec_prompt,
        )

    async def _tokenize_enc_dec_prompt_async(
        self,
        prompt: EncoderDecoderDictPrompt,
        params: TokenizeParams,
    ) -> EncoderDecoderTokPrompt:
        """
        Async counterpart of ``_tokenize_enc_dec_prompt``.

        The encoder and decoder prompts are tokenized concurrently with
        ``asyncio.gather``.
        """
        enc_prompt, dec_prompt = await asyncio.gather(
            self._tokenize_singleton_prompt_async(prompt["encoder_prompt"], params),
            (
                # `asyncio.sleep(0)` is a no-op awaitable that resolves to
                # `None`, so a missing decoder prompt yields `None` here.
                asyncio.sleep(0)
                if prompt["decoder_prompt"] is None
                else self._tokenize_singleton_prompt_async(
                    prompt["decoder_prompt"], params
                )
            ),
        )

        return EncoderDecoderTokPrompt(
            encoder_prompt=enc_prompt,
            decoder_prompt=dec_prompt,
        )

    def tokenize_prompt(
        self,
        prompt: DictPrompt,
        params: TokenizeParams,
    ) -> TokPrompt:
        if "encoder_prompt" in prompt:
            return self._tokenize_enc_dec_prompt(prompt, params)  # type: ignore[arg-type]

421
        return self._tokenize_singleton_prompt(prompt, params)
422
423
424

    def tokenize_prompts(
        self,
425
        prompts: Sequence[DictPrompt],
426
        params: TokenizeParams,
427
    ) -> list[TokPrompt]:
428
429
        return [self.tokenize_prompt(prompt, params) for prompt in prompts]

430
431
432
433
434
435
436
    async def tokenize_prompt_async(
        self,
        prompt: DictPrompt,
        params: TokenizeParams,
    ) -> TokPrompt:
        if "encoder_prompt" in prompt:
            return await self._tokenize_enc_dec_prompt_async(prompt, params)  # type: ignore[arg-type]
437

438
        return await self._tokenize_singleton_prompt_async(prompt, params)
439
440
441

    async def tokenize_prompts_async(
        self,
442
        prompts: Sequence[DictPrompt],
443
        params: TokenizeParams,
444
    ) -> list[TokPrompt]:
445
446
447
        return await asyncio.gather(
            *(self.tokenize_prompt_async(prompt, params) for prompt in prompts)
        )
448
449
450
451

    # Step 3: Add extra keys to the prompts
    def _apply_prompt_extras(
        self,
452
        prompts: Sequence[TokPrompt],
453
454
455
456
457
458
        prompt_extras: dict[str, Any] | None,
    ):
        if not prompt_extras:
            return

        for prompt in prompts:
459
            target_prompt = extract_target_prompt(self.model_config, prompt)
460
461
            target_prompt.update(prompt_extras)  # type: ignore[arg-type]

462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
    # Step 4: Convert to engine inputs
    def _validate_mm_uuids(
        self,
        mm_data: "MultiModalDataDict",
        mm_items: "MultiModalDataItems",
        mm_uuids: "MultiModalUUIDDict | None",
    ) -> None:
        if mm_uuids is None:
            mm_uuids = {}

        # NOTE: Keys corresponding to `None` in `mm_data` don't appear in `mm_items`
        modalities = mm_data.keys() | mm_uuids.keys()

        for modality in modalities:
            data_items = mm_items.get(modality) or list[Any]()

            uuid_items = mm_uuids.get(modality) or list[str | None]()
            if isinstance(uuid_items, str):
                uuid_items = [uuid_items]

            if len(data_items) > 0:
                if len(uuid_items) > 0 and len(data_items) != len(uuid_items):
                    raise ValueError(
                        f"If given, multi_modal_uuids[{modality!r}] must have "
                        f"same length as multi_modal_data[{modality!r}], but "
                        f"got {len(uuid_items)} vs {len(data_items)}."
                    )

                for i, item in enumerate(data_items):
                    if item is None:
                        if not uuid_items:
                            raise ValueError(
                                f"multi_modal_data[{modality!r}][{i}] is empty but "
                                f"multi_modal_uuids[{modality!r}] is missing."
                            )

                        if uuid_items[i] is None:
                            raise ValueError(
                                f"multi_modal_data[{modality!r}][{i}] is empty but "
                                f"multi_modal_uuids[{modality!r}][{i}] is missing."
                            )

    def _process_mm_uuids(
        self,
        mm_data: "MultiModalDataDict",
        mm_items: "MultiModalDataItems",
        mm_uuids: "MultiModalUUIDDict | None",
        mm_req_id: str,
    ):
        model_config = self.model_config

        # NOTE: When users explicitly turn off BOTH prefix caching and input
        # processing caching, no multimodal features or embeddings will be
        # reused across requests, therefore identifying multimodal data items
        # by their content is no longer necessary, and we create uuids with
        # `<mm_req_id>-<modality>-<index>`, overriding even user-provided ones.
        if (
            model_config.multimodal_config
            and model_config.multimodal_config.mm_processor_cache_gb == 0
            and not self.config.cache_config.enable_prefix_caching
        ):
            mm_uuids = {
                modality: [f"{mm_req_id}-{modality}-{i}" for i in range(data_count)]
                for modality, data_count in mm_items.get_all_counts().items()
            }

        self._validate_mm_uuids(mm_data, mm_items, mm_uuids)

        return mm_uuids

    # TODO: Remove str and tokenization_kwargs after deprecating InputPreprocessor
    def _process_multimodal(
        self,
        prompt: list[int] | str,
        mm_data: "MultiModalDataDict",
        mm_processor_kwargs: Mapping[str, object] | None,
        tokenization_kwargs: dict[str, Any] | None,
        mm_uuids: "MultiModalUUIDDict | None",
    ) -> "MultiModalInputs":
        from vllm.multimodal.processing.context import set_request_id

        mm_req_id = f"renderer-mm-{self._mm_req_counter.inc(1)}"

        mm_processor = self.get_mm_processor()

        mm_items = mm_processor.info.parse_mm_data(mm_data)
        mm_uuids = self._process_mm_uuids(mm_data, mm_items, mm_uuids, mm_req_id)

        with set_request_id(mm_req_id), set_default_torch_num_threads():
            mm_inputs = mm_processor.apply(
                prompt,
                mm_items,
                hf_processor_mm_kwargs=mm_processor_kwargs or {},
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        self.update_mm_cache_stats()

        return mm_inputs

    def _process_tokens(
        self,
        prompt: TokensPrompt,
    ) -> "TokenInputs | MultiModalInputs":
        prompt_token_ids = prompt["prompt_token_ids"]

        inputs: TokenInputs | MultiModalInputs
        if multi_modal_data := prompt.get("multi_modal_data"):
            inputs = self._process_multimodal(
                prompt_token_ids,
                multi_modal_data,
                mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
                tokenization_kwargs=None,  # Tokenization already done in Step 2
                mm_uuids=prompt.get("multi_modal_uuids"),
            )
        else:
            inputs = token_inputs(prompt_token_ids)

        if prompt_text := prompt.get("prompt"):
            inputs["prompt"] = prompt_text
        if cache_salt := prompt.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    def _process_embeds(
        self,
        prompt: EmbedsPrompt,
    ) -> EmbedsInputs:
        if not self.model_config.enable_prompt_embeds:
            raise ValueError(
                "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
            )

        prompt_embeds = prompt["prompt_embeds"]

        # prompt_embeds must be (seq_len, hidden_size), but if the user
        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # we can unambiguously process the intent by squeezing the batch
        # dimension.
        if prompt_embeds.ndim == 3:
            prompt_embeds = prompt_embeds.squeeze(dim=0)

        if prompt_embeds.ndim != 2:
            raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")

        # Tensors must be on CPU for serialization between processes
        # in the MsgpackEncoder. Casting to CPU here ensures that there is no
        # hidden device transfer in the critical path of generation.
        prompt_embeds = prompt_embeds.cpu()

        return embeds_inputs(
            prompt_embeds=prompt_embeds,
            cache_salt=prompt.get("cache_salt"),
        )

    def _process_singleton(
        self,
        prompt: SingletonTokPrompt,
    ) -> SingletonInputs:
        if "prompt_embeds" in prompt:
            return self._process_embeds(prompt)  # type: ignore[arg-type]

        return self._process_tokens(prompt)  # type: ignore[arg-type]

    def _process_enc_dec(
        self,
        prompt: EncoderDecoderTokPrompt,
    ) -> EncoderDecoderInputs:
        """
        Convert a tokenized encoder-decoder prompt into engine inputs.

        The decoder side is ``None`` when no decoder prompt was given; the
        decoder start token id comes from :meth:`get_dec_start_token_id`.
        """
        encoder_prompt = prompt["encoder_prompt"]
        decoder_prompt = prompt["decoder_prompt"]

        encoder_inputs = self._process_singleton(encoder_prompt)
        if decoder_prompt is None:
            decoder_inputs = None
        else:
            decoder_inputs = self._process_singleton(decoder_prompt)

        return build_enc_dec_inputs(
            encoder_inputs=encoder_inputs,
            decoder_inputs=decoder_inputs,
            decoder_start_token_id=self.get_dec_start_token_id(),
        )

    def process_for_engine(
        self, prompt: TokPrompt, arrival_time: float
    ) -> ProcessorInputs:
        engine_prompt: ProcessorInputs
        if "encoder_prompt" in prompt:
            engine_prompt = self._process_enc_dec(prompt)  # type: ignore[arg-type]
        else:
            engine_prompt = self._process_singleton(prompt)

        engine_prompt["arrival_time"] = arrival_time

        return engine_prompt

656
657
658
659
    # Top-level methods
    def render_cmpl(
        self,
        prompts: Sequence[DictPrompt | bytes],
660
        tok_params: TokenizeParams | None = None,
661
662
663
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
664
665
        arrival_time = time.time()

666
667
        if tok_params is None:
            tok_params = self.default_cmpl_tok_params
668

669
        dict_prompts = self.render_prompts(prompts)
670
671
672
673
        tok_prompts = self.tokenize_prompts(dict_prompts, tok_params)

        self._apply_prompt_extras(tok_prompts, prompt_extras)

674
        return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
675
676
677
678

    async def render_cmpl_async(
        self,
        prompts: Sequence[DictPrompt | bytes],
679
        tok_params: TokenizeParams | None = None,
680
681
682
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
683
684
        arrival_time = time.time()

685
686
        if tok_params is None:
            tok_params = self.default_cmpl_tok_params
687

688
        dict_prompts = await self.render_prompts_async(prompts)
689
690
691
692
        tok_prompts = await self.tokenize_prompts_async(dict_prompts, tok_params)

        self._apply_prompt_extras(tok_prompts, prompt_extras)

693
        return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
694
695
696
697
698

    def render_chat(
        self,
        conversations: Sequence[list["ChatCompletionMessageParam"]],
        chat_params: ChatParams,
699
        tok_params: TokenizeParams | None = None,
700
701
702
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
703
704
        arrival_time = time.time()

705
706
707
        if tok_params is None:
            tok_params = self.default_chat_tok_params

708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
        rendered = [
            self.render_messages(conversation, chat_params)
            for conversation in conversations
        ]

        out_conversations = list[list["ConversationMessage"]]()
        dict_prompts = list[DictPrompt]()
        for conv, prompt in rendered:
            out_conversations.append(conv)
            dict_prompts.append(prompt)

        tok_prompts = self.tokenize_prompts(dict_prompts, tok_params)

        self._apply_prompt_extras(tok_prompts, prompt_extras)

723
724
725
726
727
        eng_prompts = [
            self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
        ]

        return out_conversations, eng_prompts
728
729
730
731
732

    async def render_chat_async(
        self,
        conversations: Sequence[list["ChatCompletionMessageParam"]],
        chat_params: ChatParams,
733
        tok_params: TokenizeParams | None = None,
734
735
736
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
737
738
        arrival_time = time.time()

739
740
741
        if tok_params is None:
            tok_params = self.default_chat_tok_params

742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
        rendered = [
            self.render_messages_async(conversation, chat_params)
            for conversation in conversations
        ]

        out_conversations = list[list["ConversationMessage"]]()
        dict_prompts = list[DictPrompt]()
        for conv, prompt in await asyncio.gather(*rendered):
            out_conversations.append(conv)
            dict_prompts.append(prompt)

        tok_prompts = await self.tokenize_prompts_async(dict_prompts, tok_params)

        self._apply_prompt_extras(tok_prompts, prompt_extras)

757
758
759
760
761
        eng_prompts = [
            self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
        ]

        return out_conversations, eng_prompts