serving.py 43 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import json
5
import time
6
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
7
from dataclasses import dataclass, field
8
from http import HTTPStatus
9
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
10

11
import numpy as np
12
from fastapi import Request
13
14
15
from openai.types.responses import (
    ToolChoiceFunction,
)
16
17
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
18

19
import vllm.envs as envs
20
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
21
from vllm.config import ModelConfig
22
from vllm.engine.protocol import EngineClient
23
24
25
26
27
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
28
from vllm.entrypoints.logger import RequestLogger
29
from vllm.entrypoints.openai.chat_completion.protocol import (
30
    ChatCompletionNamedToolChoiceParam,
31
32
    ChatCompletionRequest,
    ChatCompletionResponse,
33
)
34
from vllm.entrypoints.openai.completion.protocol import (
35
36
    CompletionRequest,
    CompletionResponse,
37
38
)
from vllm.entrypoints.openai.engine.protocol import (
39
    ErrorResponse,
40
    FunctionCall,
41
    FunctionDefinition,
42
    GenerationError,
43
)
44
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
45
46
47
48
49
50
from vllm.entrypoints.openai.responses.context import (
    ConversationContext,
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
51
52
53
54
from vllm.entrypoints.openai.responses.protocol import (
    ResponseInputOutputItem,
    ResponsesRequest,
)
55
56
57
from vllm.entrypoints.openai.responses.utils import (
    construct_input_messages,
)
58
from vllm.entrypoints.openai.speech_to_text.protocol import (
59
60
61
62
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
63
from vllm.entrypoints.pooling.embed.protocol import (
64
    EmbeddingBytesResponse,
65
66
67
68
69
70
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
71
72
    PoolingChatRequest,
    PoolingCompletionRequest,
73
74
75
76
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
77
78
    ScoreDataRequest,
    ScoreQueriesDocumentsRequest,
79
80
    ScoreRequest,
    ScoreResponse,
81
    ScoreTextRequest,
82
)
83
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
84
85
86
87
88
89
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
90
from vllm.entrypoints.utils import create_error_response, get_max_tokens
91
from vllm.exceptions import VLLMValidationError
92
93
94
95
96
97
98
from vllm.inputs.data import (
    ProcessorInputs,
    PromptType,
    SingletonPrompt,
    TokensPrompt,
    token_inputs,
)
99
from vllm.logger import init_logger
100
from vllm.logprobs import Logprob, PromptLogprobs
101
from vllm.lora.request import LoRARequest
102
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
103
from vllm.pooling_params import PoolingParams
104
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
105
106
107
108
109
110
from vllm.renderers.inputs.preprocess import (
    extract_prompt_components,
    extract_prompt_len,
    parse_model_prompt,
    prompt_to_seq,
)
111
from vllm.sampling_params import BeamSearchParams, SamplingParams
112
from vllm.tokenizers import TokenizerLike
113
from vllm.tool_parsers import ToolParser
114
115
116
117
118
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
119
from vllm.utils import random_uuid
120
from vllm.utils.async_utils import (
121
    collect_from_async_generator,
122
123
    merge_async_iterators,
)
124
from vllm.utils.mistral import is_mistral_tokenizer
125
126
127

logger = init_logger(__name__)

128
129
130
131
132
133
134
135
136
137
138
139
140
141
142

class RendererRequest(Protocol):
    """Structural type for requests that can describe their tokenization."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Return the tokenization parameters for this request.

        Implementations receive the model configuration; this protocol stub
        only raises ``NotImplementedError``.
        """
        raise NotImplementedError


class RendererChatRequest(RendererRequest, Protocol):
    """A ``RendererRequest`` that can additionally describe chat rendering."""

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        """Return the chat parameters for this request.

        The caller supplies the default chat template (or ``None``) and the
        default template content format; this protocol stub only raises
        ``NotImplementedError``.
        """
        raise NotImplementedError


143
144
# Request types grouped under the "completion-like" alias.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | TokenizeCompletionRequest
    | DetokenizeRequest
    | EmbeddingCompletionRequest
    | RerankRequest
    | ScoreRequest
    | PoolingCompletionRequest
)
152

153
# Request types grouped under the "chat-like" alias.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | TokenizeChatRequest
    | EmbeddingChatRequest
    | PoolingChatRequest
)
159

160
# Speech-to-text requests: either a transcription or a translation request.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
161

162
163
164
165
166
167
# Union of every request type referenced by this module's serving code.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Union of the successful (non-error) response types used in this module.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | EmbeddingResponse
    | EmbeddingBytesResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ScoreResponse
    | GenerateResponse
)
182
183
184
185

# Type variable for request objects; bounded by the AnyRequest union.
RequestT = TypeVar("RequestT", bound=AnyRequest)


186
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
    """Per-request state passed between the stages of the serving pipeline.

    All fields are keyword-only. ``request``, ``model_name`` and
    ``request_id`` are required; the remaining fields have defaults.
    """

    # The request being served.
    request: RequestT
    # The incoming HTTP request, when available.
    raw_request: Request | None = None
    model_name: str
    request_id: str
    # Defaults to the current Unix time truncated to whole seconds.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None
    # Inputs to submit to the engine; `_prepare_generators` and
    # `_collect_batch` report an error while this is still None.
    engine_prompts: list[ProcessorInputs] | None = None

    # Assigned by `_prepare_generators`; yields (prompt index, output) pairs.
    result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
        None
    )
    # Assigned by `_collect_batch`, ordered by prompt index.
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # NOTE(review): this attribute has no annotation, so `dataclass` does not
    # treat it as a field. It looks like a pydantic-style setting; confirm
    # whether it is still needed.
    model_config = ConfigDict(arbitrary_types_allowed=True)
202
203


204
class OpenAIServing:
205
    # NOTE: the value below is descriptive text, not an actual prefix.
    # Presumably subclasses override it with a short tag - confirm.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd")
    so you can easily tell “this ID came from Embedding.”
    """
209

210
211
    def __init__(
        self,
212
        engine_client: EngineClient,
213
        models: OpenAIServingModels,
214
        *,
215
        request_logger: RequestLogger | None,
216
        return_tokens_as_token_ids: bool = False,
217
    ):
218
219
        super().__init__()

220
        self.engine_client = engine_client
221

222
        self.models = models
223

224
        self.request_logger = request_logger
225
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
226

227
228
229
230
        self.model_config = engine_client.model_config
        self.renderer = engine_client.renderer
        self.io_processor = engine_client.io_processor
        self.input_processor = engine_client.input_processor
231
232
233

    async def beam_search(
        self,
        prompt: ProcessorInputs,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search by repeatedly asking the engine for one token.

        At every step each live beam is submitted to the engine as a
        single-token generation request asking for ``2 * beam_width``
        logprobs. The returned candidates of all beams are pooled; EOS
        candidates are moved to the completed list (unless ``ignore_eos``),
        and the ``beam_width`` best remaining candidates become the beams of
        the next step. After ``max_tokens`` steps the surviving beams are
        merged with the completed ones, sorted with the beam sort key, and
        the best ``beam_width`` are yielded as one final ``RequestOutput``.

        If any engine result finishes with reason ``"error"``, a single
        error ``RequestOutput`` is yielded and the search stops.

        Raises:
            NotImplementedError: for ``embeds`` and ``enc_dec`` prompts.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        if prompt["type"] == "embeds":
            raise NotImplementedError("Embedding prompt not supported for beam search")
        if prompt["type"] == "enc_dec":
            raise NotImplementedError(
                "Encoder-decoder prompt not supported for beam search"
            )

        prompt_text = prompt.get("prompt")
        prompt_token_ids = prompt["prompt_token_ids"]
        tokenized_length = len(prompt_token_ids)

        logprobs_num = 2 * beam_width
        sampling_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, beam in enumerate(all_beams):
                prompt_item = beam.get_prompt()
                lora_request_item = beam.lora_request
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            prompt_item,
                            sampling_params,
                            request_id_item,
                            lora_request=lora_request_item,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Store all new tokens generated by beam
            all_beams_token_id = []
            # Store the cumulative probability of all tokens
            # generated by beam search
            all_beams_logprob = []
            # Index of the beam each candidate came from. Tracked explicitly
            # because a result may carry a different number of logprob
            # entries than requested (or none), so the owner cannot be
            # derived from the candidate's position.
            candidate_beam_idx: list[int] = []
            # Iterate through all beam inference results
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    all_beams_token_id.extend(list(logprobs.keys()))
                    all_beams_logprob.extend(
                        [
                            current_beam.cum_logprob + obj.logprob
                            for obj in logprobs.values()
                        ]
                    )
                    candidate_beam_idx.extend([i] * len(logprobs))

            # Handle the token for the end of sentence (EOS)
            all_beams_token_id = np.array(all_beams_token_id)
            all_beams_logprob = np.array(all_beams_logprob)

            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
                for idx in eos_idx:
                    beam_idx = candidate_beam_idx[idx]
                    current_beam = all_beams[beam_idx]
                    result = output[beam_idx]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            orig_prompt=prompt,
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # After processing, set the log probability of the eos condition
                # to negative infinity.
                all_beams_logprob[eos_idx] = -np.inf

            # Processing non-EOS tokens
            # Get indices of the top beam_width probabilities.
            # argpartition requires kth < len(array), so fall back to taking
            # every candidate when there are not more than beam_width.
            num_candidates = len(all_beams_logprob)
            if num_candidates > beam_width:
                topn_idx = np.argpartition(
                    np.negative(all_beams_logprob), beam_width
                )[:beam_width]
            else:
                topn_idx = np.arange(num_candidates)

            for idx in topn_idx:
                beam_idx = candidate_beam_idx[idx]
                current_beam = all_beams[beam_idx]
                result = output[beam_idx]
                token_id = int(all_beams_token_id[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        orig_prompt=prompt,
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                    )
                )

            all_beams = new_beams

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
429

430
431
432
    async def _preprocess(
        self,
        ctx: ServeContext,
433
    ) -> ErrorResponse | None:
434
435
        """
        Default preprocessing hook. Subclasses may override
436
        to prepare `ctx` (embedding, etc.).
437
438
439
440
441
442
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """
        Default response builder. Subclass may override this method
        to return the appropriate response object.

        This base implementation always returns an error response.
        """
        return self.create_error_response("unimplemented endpoint")

    async def handle(
        self,
        ctx: ServeContext,
453
    ) -> AnyResponse | ErrorResponse:
454
        async for response in self._pipeline(ctx):
455
456
457
458
459
460
461
            return response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
462
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
        """Execute the request processing pipeline yielding responses."""
        if error := await self._check_model(ctx.request):
            yield error
        if error := self._validate_request(ctx):
            yield error

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret

        yield self._build_response(ctx)

483
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
484
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
485

486
487
        if (
            truncate_prompt_tokens is not None
488
            and truncate_prompt_tokens > self.model_config.max_model_len
489
        ):
490
491
492
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
493
494
                " Please, select a smaller truncation size."
            )
495
496
        return None

497
498
499
    def _create_pooling_params(
        self,
        ctx: ServeContext,
500
    ) -> PoolingParams | ErrorResponse:
501
502
        if not hasattr(ctx.request, "to_pooling_params"):
            return self.create_error_response(
503
504
                "Request type does not support pooling parameters"
            )
505
506
507

        return ctx.request.to_pooling_params()

508
509
510
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Schedule the request and get the result generator.

        One `encode` call is issued per engine prompt, each with the request
        ID `"<ctx.request_id>-<index>"`. The resulting generators are merged
        and stored on `ctx.result_generator`.

        Returns an error response if pooling parameters cannot be built or
        if `ctx.engine_prompts` is `None`; otherwise `None`.
        """
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

        trace_headers = (
            None
            if ctx.raw_request is None
            else await self._get_trace_headers(ctx.raw_request.headers)
        )

        pooling_params = self._create_pooling_params(ctx)
        if isinstance(pooling_params, ErrorResponse):
            return pooling_params

        if ctx.engine_prompts is None:
            return self.create_error_response("Engine prompts not available")

        for i, engine_prompt in enumerate(ctx.engine_prompts):
            request_id_item = f"{ctx.request_id}-{i}"

            self._log_inputs(
                request_id_item,
                engine_prompt,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            generator = self.engine_client.encode(
                engine_prompt,
                pooling_params,
                request_id_item,
                lora_request=ctx.lora_request,
                trace_headers=trace_headers,
                # Requests without a `priority` attribute use 0.
                priority=getattr(ctx.request, "priority", 0),
            )

            generators.append(generator)

        ctx.result_generator = merge_async_iterators(*generators)

        return None
552
553
554
555

    async def _collect_batch(
        self,
        ctx: ServeContext,
556
    ) -> ErrorResponse | None:
557
        """Collect batch results from the result generator."""
558
559
        if ctx.engine_prompts is None:
            return self.create_error_response("Engine prompts not available")
560

561
562
563
        num_prompts = len(ctx.engine_prompts)
        final_res_batch: list[PoolingRequestOutput | None]
        final_res_batch = [None] * num_prompts
564

565
566
        if ctx.result_generator is None:
            return self.create_error_response("Result generator not available")
567

568
569
        async for i, res in ctx.result_generator:
            final_res_batch[i] = res
570

571
572
573
574
        if None in final_res_batch:
            return self.create_error_response(
                "Failed to generate results for all prompts"
            )
575

576
        ctx.final_res_batch = [res for res in final_res_batch if res is not None]
577

578
        return None
579

580
    @staticmethod
    def create_error_response(
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an `ErrorResponse` via the module-level helper of the same name.

        Defaults describe a 400 "BadRequestError"; all arguments are
        forwarded unchanged.
        """
        return create_error_response(message, err_type, status_code, param)
588

589
    def create_streaming_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> str:
        """Return the error response for the given arguments as a JSON string.

        The response is built with `create_error_response`, dumped with
        `model_dump()` and serialized with `json.dumps`.
        """
        json_str = json.dumps(
            self.create_error_response(
                message=message,
                err_type=err_type,
                status_code=status_code,
                param=param,
            ).model_dump()
        )
        return json_str

606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Raise GenerationError if finish_reason indicates an error."""
        if finish_reason == "error":
            logger.error(
                "Request %s failed with an internal error during generation",
                request_id,
            )
            raise GenerationError("Internal server error")

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Serialize a `GenerationError` as a streaming error payload.

        The error's string form becomes the message and its `status_code`
        is forwarded; the error type is always ``"InternalServerError"``.
        """
        detail = str(e)
        status = e.status_code
        return self.create_streaming_error_response(
            detail,
            err_type="InternalServerError",
            status_code=status,
        )

625
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Check that `request.model` names a model this server can serve.

        Returns `None` when the model is supported, is a registered LoRA,
        or (with `VLLM_ALLOW_RUNTIME_LORA_UPDATING` enabled) resolves to a
        LoRA at runtime. Otherwise returns an error response: the resolver's
        own error when it reported a 400, else a 404 "NotFoundError".
        """
        error_response = None

        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
            if isinstance(load_result, LoRARequest):
                return None
            # Only a bad-request error from the resolver is passed through;
            # anything else falls back to the generic 404 below.
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
                error_response = load_result

        return error_response or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )
654

655
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

678
    def _maybe_get_adapters(
679
680
681
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
682
    ) -> LoRARequest | None:
683
        if request.model in self.models.lora_requests:
684
            return self.models.lora_requests[request.model]
685
686
687
688
689
690

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
691
                return default_mm_lora
692
693

        if self._is_model_supported(request.model):
694
            return None
695

696
        # if _check_model has been called earlier, this will be unreachable
697
        raise ValueError(f"The model `{request.model}` does not exist.")
698

699
700
701
702
703
704
705
706
707
708
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

709
710
711
712
713
        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
714
715
716
717
718
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
719
720
721
722
723
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

724
725
    def _validate_input(
        self,
        request: object,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Check the prompt length against the model limit and wrap it.

        Embedding, score and rerank requests may use the full context
        length. Tokenize/detokenize requests are not length-checked. All
        other requests must leave room for at least one generated token,
        and for `max_tokens` when it is set.

        Returns:
            A `TokensPrompt` holding `input_text` and `input_ids`.

        Raises:
            VLLMValidationError: if the input, or input plus requested
                output tokens, does not fit in `max_model_len`.
        """
        token_num = len(input_ids)
        max_model_len = self.model_config.max_model_len

        # Note: EmbeddingRequest,
        # and ScoreRequest doesn't have max_tokens
        if isinstance(
            request,
            (
                EmbeddingChatRequest,
                EmbeddingCompletionRequest,
                ScoreDataRequest,
                ScoreTextRequest,
                ScoreQueriesDocumentsRequest,
                RerankRequest,
            ),
        ):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num > max_model_len:
                # Request types not listed here are reported as
                # "embedding generation".
                operations: dict[type[AnyRequest], str] = {
                    ScoreDataRequest: "score",
                    ScoreTextRequest: "score",
                    ScoreQueriesDocumentsRequest: "score",
                }
                operation = operations.get(type(request), "embedding generation")
                raise VLLMValidationError(
                    f"This model's maximum context length is "
                    f"{max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.",
                    parameter="input_tokens",
                    value=token_num,
                )
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        if isinstance(
            request,
            (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
        ):
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= max_model_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages.",
                parameter="input_tokens",
                value=token_num,
            )

        if max_tokens is not None and token_num + max_tokens > max_model_len:
            raise VLLMValidationError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{max_model_len} tokens and your request has "
                f"{token_num} input tokens ({max_tokens} > {max_model_len}"
                f" - {token_num}).",
                parameter="max_tokens",
                value=max_tokens,
            )

        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
804

805
806
    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Refuse chat templates supplied with a request unless they are trusted.

        Args:
            request_chat_template: Template passed directly with the request.
            chat_template_kwargs: Request-level template kwargs; a non-``None``
                ``"chat_template"`` entry also counts as a supplied template.
            trust_request_chat_template: When true, no check is performed.

        Returns:
            The result of ``self.create_error_response`` when an untrusted
            template is present, otherwise ``None``.
        """
        if trust_request_chat_template:
            return None

        template_supplied = request_chat_template is not None
        if not template_supplied and chat_template_kwargs:
            # The template may also be smuggled in through the kwargs dict.
            template_supplied = chat_template_kwargs.get("chat_template") is not None

        if not template_supplied:
            return None

        return self.create_error_response(
            "Chat template is passed with request, but "
            "--trust-request-chat-template is not set. "
            "Refused request with untrusted chat template."
        )

825
826
827
828
829
830
831
832
833
834
835
836
    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs."""
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
        return default_chat_template_kwargs | request_chat_template_kwargs

837
838
839
840
841
    async def _preprocess_completion(
        self,
        request: RendererRequest,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
    ) -> list[ProcessorInputs]:
        """Collect completion prompts and hand them to ``_preprocess_cmpl``.

        Both sources are optional. Each one that is not ``None`` is expanded
        with ``prompt_to_seq`` and appended to a single list, embeddings
        first and text/token prompts second.
        """
        prompts: list[SingletonPrompt | bytes] = []
        # Order matters: embeds take higher priority, so they go first.
        for source in (prompt_embeds, prompt_input):
            if source is not None:
                prompts.extend(prompt_to_seq(source))

        return await self._preprocess_cmpl(request, prompts)

    async def _preprocess_cmpl(
        self,
        request: RendererRequest,
        prompts: Sequence[PromptType | bytes],
    ) -> list[ProcessorInputs]:
        """Parse completion prompts and render them through ``self.renderer``.

        ``bytes`` entries are forwarded untouched; every other prompt goes
        through ``parse_model_prompt``. ``mm_processor_kwargs`` and
        ``cache_salt`` are copied from the request into ``prompt_extras``
        when the request has them set to a non-``None`` value.
        """
        renderer = self.renderer
        model_config = self.model_config

        parsed_prompts = [
            prompt
            if isinstance(prompt, bytes)
            else parse_model_prompt(model_config, prompt)
            for prompt in prompts
        ]
        tok_params = request.build_tok_params(model_config)

        # Not every request type defines these fields, hence getattr.
        prompt_extras = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                prompt_extras[key] = value

        return await renderer.render_cmpl_async(
            parsed_prompts,
            tok_params,
            prompt_extras=prompt_extras,
        )
878

879
880
    async def _preprocess_chat(
        self,
        request: RendererChatRequest,
        messages: list[ChatCompletionMessageParam],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
    ) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
        """Render one chat conversation into a single engine prompt.

        Args:
            request: Request used to build tokenization and chat parameters.
            messages: The conversation to render.
            default_template: Template passed to ``request.build_chat_params``.
            default_template_content_format: Content format passed alongside
                ``default_template``.
            default_template_kwargs: Default template kwargs; merged with
                ``tools`` and ``tokenize`` entries before being applied.
            tool_dicts: Tool definitions forwarded as the ``tools`` kwarg.
            tool_parser: Factory building a tool parser from a tokenizer.

        Returns:
            The rendered conversation and a one-element list holding the
            engine prompt.

        Raises:
            NotImplementedError: If a tool parser is set, ``tool_choice`` is
                not ``"none"``, and the request is neither a
                ``ChatCompletionRequest`` nor a ``ResponsesRequest``.
        """
        renderer = self.renderer

        default_template_kwargs = merge_kwargs(
            default_template_kwargs,
            dict(
                tools=tool_dicts,
                tokenize=is_mistral_tokenizer(renderer.tokenizer),
            ),
        )

        tok_params = request.build_tok_params(self.model_config)
        chat_params = request.build_chat_params(
            default_template, default_template_content_format
        ).with_defaults(default_template_kwargs)

        # Not every request type defines these fields, hence getattr.
        prompt_extras = {
            k: v
            for k in ("mm_processor_kwargs", "cache_salt")
            if (v := getattr(request, k, None)) is not None
        }

        # A single conversation goes in, so exactly one conversation and one
        # engine prompt are expected back; the unpacking enforces that.
        (conversation,), (engine_prompt,) = await renderer.render_chat_async(
            [messages],
            chat_params,
            tok_params,
            prompt_extras=prompt_extras,
        )

        # Tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the
        # LLM).
        if (
            tool_parser is not None
            and getattr(request, "tool_choice", "none") != "none"
        ):
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                msg = (
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )
                raise NotImplementedError(msg)

            # TODO: Update adjust_request to accept ResponsesRequest
            tokenizer = renderer.get_tokenizer()
            # NOTE(review): the return value is not used by this method;
            # presumably adjust_request mutates `request` in place - confirm.
            tool_parser(tokenizer).adjust_request(request=request)  # type: ignore[arg-type]

        return conversation, [engine_prompt]
933

934
    def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs):
        """Return ``extract_prompt_components`` for *prompt* under this model's config."""
        model_config = self.model_config
        return extract_prompt_components(model_config, prompt)

937
    def _extract_prompt_text(self, prompt: ProcessorInputs):
        """Return the ``text`` attribute of the prompt's extracted components."""
        components = self._extract_prompt_components(prompt)
        return components.text

940
    def _extract_prompt_len(self, prompt: ProcessorInputs):
        """Return ``extract_prompt_len`` for *prompt* under this model's config."""
        model_config = self.model_config
        return extract_prompt_len(model_config, prompt)

943
944
945
946
947
    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        """Render the engine prompts for the next turn of a Responses request.

        The response items are converted with ``construct_input_messages``
        and rendered through ``_preprocess_chat`` with no default template
        kwargs. Only the engine prompts are returned; the rendered
        conversation is discarded.
        """
        next_turn_messages = construct_input_messages(
            request_input=messages,
        )

        _conversation, engine_prompts = await self._preprocess_chat(
            request,
            next_turn_messages,
            default_template=chat_template,
            default_template_content_format=chat_template_content_format,
            default_template_kwargs=None,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
        )
        return engine_prompts
966

967
968
969
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        engine_prompt: ProcessorInputs,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        trace_headers: Mapping[str, str] | None = None,
    ):
        """Generate, re-prompting the engine after each built-in tool call.

        This is an async generator. Every engine output is appended to
        ``context`` and ``context`` is yielded. When a generation finishes
        and ``context.need_builtin_tool_call()`` is true, the tool is called,
        its output is appended to the context, a new prompt is rendered and
        another sub-request is issued. The loop ends once no tool call is
        requested.

        Note that ``sampling_params.max_tokens`` is overwritten in place
        before every follow-up sub-request.

        Args:
            request_id: Base id; sub-requests use ``"{request_id}_{n}"``.
            engine_prompt: Prompt for the first sub-request.
            sampling_params: Sampling parameters (mutated, see above).
            context: Conversation state accumulating outputs and tool results.
            lora_request: Optional LoRA request forwarded to the engine.
            priority: Priority of the first sub-request.
            trace_headers: Optional tracing headers forwarded to the engine.
        """
        max_model_len = self.model_config.max_model_len

        orig_priority = priority
        sub_request = 0
        while True:
            # Ensure that each sub-request has a unique request id.
            sub_request_id = f"{request_id}_{sub_request}"

            self._log_inputs(
                sub_request_id,
                engine_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )

            generator = self.engine_client.generate(
                engine_prompt,
                sampling_params,
                sub_request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )

            async for res in generator:
                context.append_output(res)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                break

            # Call the tool and update the context with the result.
            tool_output = await context.call_tool()
            context.append_tool_output(tool_output)

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn.
            # Render the next prompt token ids and update sampling_params.
            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                token_ids = context.render_for_completion()
                engine_prompt = token_inputs(token_ids)

                # Whatever room is left in the context window after the
                # re-rendered prompt becomes the generation budget.
                sampling_params.max_tokens = max_model_len - len(token_ids)
            elif isinstance(context, ParsableContext):
                # Exactly one engine prompt is expected for the next turn.
                (engine_prompt,) = await self._render_next_turn(
                    context.request,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )

                sampling_params.max_tokens = get_max_tokens(
                    max_model_len,
                    context.request.max_output_tokens,
                    self._extract_prompt_len(engine_prompt),
                    self.default_sampling_params,  # type: ignore
                    self.override_max_tokens,  # type: ignore
                )

            # OPTIMIZATION
            # Every follow-up sub-request runs at `orig_priority - 1`
            # (computed from the original value, so it does not keep
            # decreasing). NOTE(review): presumably a lower value is
            # scheduled earlier - confirm against the scheduler.
            priority = orig_priority - 1
            sub_request += 1
1045

1046
1047
1048
    def _log_inputs(
        self,
        request_id: str,
        inputs: PromptType | ProcessorInputs,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        """Forward a request's prompt components to the request logger.

        Does nothing when ``self.request_logger`` is ``None``. Otherwise the
        prompt is split with ``_extract_prompt_components`` and its text,
        token ids and embeds are passed to ``request_logger.log_inputs``
        together with ``params`` and ``lora_request``.
        """
        request_logger = self.request_logger
        if request_logger is None:
            return

        components = self._extract_prompt_components(inputs)

        request_logger.log_inputs(
            request_id,
            components.text,
            components.token_ids,
            components.embeds,
            params=params,
            lora_request=lora_request,
        )
1066

1067
1068
1069
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return the trace headers of a request when tracing is enabled.

        If the engine client reports tracing as disabled, ``None`` is
        returned; when the request nevertheless carries trace headers,
        ``log_tracing_disabled_warning`` is called first.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

1081
    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Pulls the request id to use from a header, if provided.

        Resolution order: the ``X-Request-Id`` header of ``raw_request``,
        then ``default``, then a fresh value from ``random_uuid()``.
        """
        if raw_request is not None:
            header_id = raw_request.headers.get("X-Request-Id")
            if header_id is not None:
                return header_id

        if default is not None:
            return default
        return random_uuid()
1092

1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Read the data parallel rank from the ``X-data-parallel-rank`` header.

        Returns ``None`` when there is no request, the header is absent, or
        its value cannot be converted with ``int()``.
        """
        header_value = (
            None
            if raw_request is None
            else raw_request.headers.get("X-data-parallel-rank")
        )
        if header_value is None:
            return None

        try:
            rank = int(header_value)
        except ValueError:
            # A malformed header is treated the same as a missing one.
            return None
        return rank

1108
1109
1110
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Turn model output into function calls according to ``tool_choice``.

        Branches, in order:

        * ``ToolChoiceFunction`` or ``ChatCompletionNamedToolChoiceParam``:
          the whole ``content`` becomes the arguments of the named function.
        * ``"required"``: ``content`` is validated as a JSON list of
          ``FunctionDefinition`` and each entry becomes one call.
        * ``"auto"`` or ``None``, with a parser class and auto tools enabled:
          the tool parser extracts the calls from ``content``.
        * anything else: no parsing is attempted.

        Returns:
            ``(function_calls, content)``. In the first two cases the content
            is ``None``. In the automatic case it is the parser's remaining
            content, or ``(None, content)`` with the content untouched when
            the parser found no tool call. When no branch applies, an empty
            list is returned with the content untouched.

        Raises:
            ValueError: If automatic parsing is requested without a tokenizer.
            RuntimeError: Re-raised (after logging) when the tool parser
                cannot be constructed.
        """
        tool_choice = request.tool_choice
        function_calls: list[FunctionCall] = []

        if tool_choice and isinstance(tool_choice, ToolChoiceFunction):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=tool_choice.name, arguments=content)
            )
            # Content is cleared since the tool is called.
            return function_calls, None

        if tool_choice and isinstance(
            tool_choice, ChatCompletionNamedToolChoiceParam
        ):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=tool_choice.function.name, arguments=content)
            )
            # Content is cleared since the tool is called.
            return function_calls, None

        if tool_choice == "required":
            assert content is not None
            tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
            function_calls.extend(
                FunctionCall(
                    name=tool_call.name,
                    arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
                )
                for tool_call in tool_calls
            )
            # Content is cleared since the tool is called.
            return function_calls, None

        wants_auto_parsing = (
            tool_parser_cls
            and enable_auto_tools
            and (tool_choice == "auto" or tool_choice is None)
        )
        if not wants_auto_parsing:
            return function_calls, content

        if tokenizer is None:
            raise ValueError(
                "Tokenizer not available when `skip_tokenizer_init=True`"
            )

        # Automatic Tool Call Parsing
        try:
            tool_parser = tool_parser_cls(tokenizer)
        except RuntimeError:
            logger.exception("Error in tool parser creation.")
            raise

        tool_call_info = tool_parser.extract_tool_calls(
            content if content is not None else "",
            request=request,  # type: ignore
        )
        if tool_call_info is None or not tool_call_info.tools_called:
            # No tool calls.
            return None, content

        # extract_tool_calls() returns a list of tool calls.
        function_calls.extend(
            FunctionCall(
                id=tool_call.id,
                name=tool_call.function.name,
                arguments=tool_call.function.arguments,
            )
            for tool_call in tool_call_info.tool_calls
        )
        content = tool_call_info.content
        # Whitespace-only leftover text is reported as no content.
        if content and content.strip() == "":
            content = None

        return function_calls, content

1185
    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: TokenizerLike | None,
        return_as_token_id: bool = False,
    ) -> str:
        """Return the string form of a token for logprob output.

        Args:
            logprob: Logprob entry; its ``decoded_token`` is used when set.
            token_id: Id of the token.
            tokenizer: Used to decode ``token_id`` when the logprob entry
                carries no decoded token.
            return_as_token_id: When true, return ``"token_id:<id>"`` and
                skip decoding altogether.

        Raises:
            ValueError: If decoding is needed but no tokenizer is available.
        """
        if return_as_token_id:
            return f"token_id:{token_id}"

        decoded_token = logprob.decoded_token
        if decoded_token is not None:
            return decoded_token

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

        return tokenizer.decode([token_id])
1204

1205
    def _is_model_supported(self, model_name: str | None) -> bool:
1206
1207
        if not model_name:
            return True
1208
        return self.models.is_base_model(model_name)
1209

1210
1211

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    ``None`` input and ``None`` entries are passed over untouched. The same
    object that was passed in is returned.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    negative_infinity = float("-inf")
    for logprob_dict in prompt_logprobs:
        if logprob_dict is None:
            continue
        for logprob_values in logprob_dict.values():
            if logprob_values.logprob == negative_infinity:
                logprob_values.logprob = -9999.0
    return prompt_logprobs