serving.py 42.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import json
5
import time
6
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
7
from dataclasses import dataclass, field
8
from http import HTTPStatus
9
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
10

11
import numpy as np
12
from fastapi import Request
13
14
15
from openai.types.responses import (
    ToolChoiceFunction,
)
16
17
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
18

19
import vllm.envs as envs
20
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
21
from vllm.config import ModelConfig
22
from vllm.engine.protocol import EngineClient
23
24
25
26
27
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
28
from vllm.entrypoints.logger import RequestLogger
29
from vllm.entrypoints.openai.chat_completion.protocol import (
30
    ChatCompletionNamedToolChoiceParam,
31
32
    ChatCompletionRequest,
    ChatCompletionResponse,
33
)
34
from vllm.entrypoints.openai.completion.protocol import (
35
36
    CompletionRequest,
    CompletionResponse,
37
38
)
from vllm.entrypoints.openai.engine.protocol import (
39
    ErrorResponse,
40
    FunctionCall,
41
    FunctionDefinition,
42
    GenerationError,
43
)
44
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
45
46
47
48
49
50
from vllm.entrypoints.openai.responses.context import (
    ConversationContext,
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
51
52
53
54
from vllm.entrypoints.openai.responses.protocol import (
    ResponseInputOutputItem,
    ResponsesRequest,
)
55
56
57
from vllm.entrypoints.openai.responses.utils import (
    construct_input_messages,
)
58
from vllm.entrypoints.openai.speech_to_text.protocol import (
59
60
61
62
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
63
64
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
65
66
    PoolingChatRequest,
    PoolingCompletionRequest,
67
68
69
70
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
71
72
    ScoreDataRequest,
    ScoreQueriesDocumentsRequest,
73
74
    ScoreRequest,
    ScoreResponse,
75
    ScoreTextRequest,
76
)
77
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
78
79
80
81
82
83
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
84
from vllm.entrypoints.utils import create_error_response, get_max_tokens
85
from vllm.exceptions import VLLMValidationError
86
87
88
89
90
91
92
from vllm.inputs.data import (
    ProcessorInputs,
    PromptType,
    SingletonPrompt,
    TokensPrompt,
    token_inputs,
)
93
from vllm.logger import init_logger
94
from vllm.logprobs import Logprob, PromptLogprobs
95
from vllm.lora.request import LoRARequest
96
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
97
from vllm.pooling_params import PoolingParams
98
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
99
100
101
102
103
104
from vllm.renderers.inputs.preprocess import (
    extract_prompt_components,
    extract_prompt_len,
    parse_model_prompt,
    prompt_to_seq,
)
105
from vllm.sampling_params import BeamSearchParams, SamplingParams
106
from vllm.tokenizers import TokenizerLike
107
from vllm.tool_parsers import ToolParser
108
109
110
111
112
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
113
from vllm.utils import random_uuid
114
from vllm.utils.async_utils import (
115
    collect_from_async_generator,
116
117
    merge_async_iterators,
)
118
from vllm.utils.mistral import is_mistral_tokenizer
119
120
121

# Module-level logger named after this module (used e.g. by _raise_if_error).
logger = init_logger(__name__)

122
123
124
125
126
127
128
129
130
131
132
133
134
135
136

class RendererRequest(Protocol):
    """Structural type: a request that can derive its tokenization params."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Return the ``TokenizeParams`` for this request under ``model_config``."""
        raise NotImplementedError()


class RendererChatRequest(RendererRequest, Protocol):
    """Structural type: a ``RendererRequest`` that can also build chat params."""

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        """Return the ``ChatParams`` given the server's template defaults."""
        raise NotImplementedError()


137
138
# Request types grouped as "completion-like" (presumably prompts supplied as
# text / token IDs rather than chat messages -- inferred from the names).
# NOTE: member order of these unions is kept as-is; typing introspection and
# pydantic union validation can observe it.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | TokenizeCompletionRequest
    | DetokenizeRequest
    | RerankRequest
    | ScoreRequest
    | PoolingCompletionRequest
)

# Request types grouped as "chat-like"; _get_message_types() inspects a
# ``messages`` attribute when a request has one.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest | TokenizeChatRequest | PoolingChatRequest
)

# Audio transcription / translation request types.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Union of every request type listed above plus Responses, IO-processor and
# disaggregated-generate requests; used as the bound of RequestT.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Union of the non-error response types a handler may produce
# (see _build_response / handle).
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ScoreResponse
    | GenerateResponse
)

# Type variable for the concrete request type carried by a ServeContext.
RequestT = TypeVar("RequestT", bound=AnyRequest)


174
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
    """Mutable per-request state passed between ``OpenAIServing`` pipeline stages.

    All fields are keyword-only (``kw_only=True``). That is what allows the
    required ``model_name`` / ``request_id`` fields to follow fields that
    have defaults.
    """

    # The API request being served.
    request: RequestT
    # Incoming HTTP request, if any; _prepare_generators reads its headers
    # to derive trace headers.
    raw_request: Request | None = None

    model_name: str
    # Base ID; per-prompt engine requests use f"{request_id}-{index}".
    request_id: str
    # Creation time as whole seconds since the epoch (int(time.time())).
    created_time: int = field(default_factory=lambda: int(time.time()))
    # Adapter forwarded to engine_client.encode for every prompt.
    lora_request: LoRARequest | None = None
    # Prompts to submit. Must be populated (e.g. by a _preprocess override)
    # before _prepare_generators / _collect_batch, which error out on None.
    engine_prompts: list[ProcessorInputs] | None = None

    # Merged stream of (prompt index, output) pairs; set by
    # _prepare_generators and drained by _collect_batch.
    result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
        None
    )
    # Final outputs ordered by prompt index; filled by _collect_batch.
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # NOTE(review): a pydantic ConfigDict on a stdlib dataclass. It has no
    # annotation, so @dataclass treats it as a plain class attribute (not a
    # field). It presumably has no effect here and looks like a leftover from
    # a pydantic model -- confirm nothing reads it before removing.
    model_config = ConfigDict(arbitrary_types_allowed=True)
190
191


192
class OpenAIServing:
193
    request_id_prefix: ClassVar[str] = """
194
    A short string prepended to every request’s ID.
195
    """
196

197
198
    def __init__(
        self,
199
        engine_client: EngineClient,
200
        models: OpenAIServingModels,
201
        *,
202
        request_logger: RequestLogger | None,
203
        return_tokens_as_token_ids: bool = False,
204
    ):
205
206
        super().__init__()

207
        self.engine_client = engine_client
208

209
        self.models = models
210

211
        self.request_logger = request_logger
212
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
213

214
215
216
217
        self.model_config = engine_client.model_config
        self.renderer = engine_client.renderer
        self.io_processor = engine_client.io_processor
        self.input_processor = engine_client.input_processor
218
219
220

    async def beam_search(
        self,
        prompt: ProcessorInputs,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search as a sequence of single-token generate calls.

        At every step, each live beam is sent to the engine as its own request
        with ``max_tokens=1`` and ``logprobs=2 * beam_width``. The returned
        logprobs become candidate extensions, and the ``beam_width``
        candidates with the highest cumulative logprob become the next beams.
        Unless ``params.ignore_eos`` is set, candidates ending in the
        tokenizer's EOS token are moved to the completed pool and cannot be
        selected for extension.

        Args:
            prompt: Processed prompt. For ``enc_dec`` prompts, the decoder
                prompt supplies the starting tokens.
            request_id: ID of the overall request. Per-beam engine request IDs
                are derived from it.
            params: Beam search parameters.
            lora_request: Optional LoRA adapter applied to the initial beam.
            trace_headers: Optional tracing headers forwarded to the engine.

        Yields:
            A single finished ``RequestOutput``. It holds either the best
            ``beam_width`` beams, or one empty output with
            ``finish_reason="error"`` if any engine step finished with an error.

        Raises:
            NotImplementedError: If ``prompt`` is an embeddings prompt.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        if prompt["type"] == "embeds":
            raise NotImplementedError("Embedding prompt not supported for beam search")

        # Extract prompt tokens and text based on model type
        decoder_prompt = (
            prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
        )
        prompt_text = decoder_prompt.get("prompt")
        prompt_token_ids = decoder_prompt["prompt_token_ids"]

        tokenized_length = len(prompt_token_ids)

        logprobs_num = 2 * beam_width
        sampling_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            if not all_beams:
                # No beam produced any candidate on the previous step, so
                # there is nothing left to extend.
                break

            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, beam in enumerate(all_beams):
                prompt_item = beam.get_prompt()
                lora_request_item = beam.lora_request
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            prompt_item,
                            sampling_params,
                            request_id_item,
                            lora_request=lora_request_item,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Token ID of every candidate extension, across all beams.
            all_beams_token_id = []
            # Cumulative logprob of the beam after appending that candidate.
            all_beams_logprob = []
            # Index (into all_beams / output) of the beam each candidate
            # extends. Tracked explicitly rather than via idx // logprobs_num:
            # a result with logprobs=None contributes no candidates, and nothing
            # here guarantees exactly logprobs_num entries per beam (NOTE: the
            # logprobs contract may also include the sampled token).
            all_beams_source = []
            # Iterate through all beam inference results
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    all_beams_token_id.extend(logprobs.keys())
                    all_beams_logprob.extend(
                        current_beam.cum_logprob + obj.logprob
                        for obj in logprobs.values()
                    )
                    all_beams_source.extend([i] * len(logprobs))

            # Handle the token for the end of sentence (EOS)
            all_beams_token_id = np.array(all_beams_token_id)
            all_beams_logprob = np.array(all_beams_logprob)

            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
                for idx in eos_idx:
                    source = all_beams_source[idx]
                    current_beam = all_beams[source]
                    result = output[source]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            orig_prompt=prompt,
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # After processing, set the log probability of the eos condition
                # to negative infinity.
                all_beams_logprob[eos_idx] = -np.inf

            # Processing non-EOS tokens
            # Get indices of the top beam_width probabilities
            num_candidates = len(all_beams_logprob)
            if num_candidates <= beam_width:
                # argpartition requires kth < len(array); keep every candidate.
                topn_idx = np.arange(num_candidates)
            else:
                topn_idx = np.argpartition(
                    np.negative(all_beams_logprob), beam_width
                )[:beam_width]

            for idx in topn_idx:
                source = all_beams_source[idx]
                current_beam = all_beams[source]
                result = output[source]
                token_id = int(all_beams_token_id[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        orig_prompt=prompt,
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                    )
                )

            all_beams = new_beams

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
417

418
419
420
    async def _preprocess(
        self,
        ctx: ServeContext,
421
    ) -> ErrorResponse | None:
422
        """
423
        Default preprocessing hook. Subclasses may override to prepare `ctx`.
424
425
426
427
428
429
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Turn the state collected in ``ctx`` into the endpoint's response.

        The base class cannot build any concrete response. It always returns
        an "unimplemented endpoint" error; subclasses override this method.
        """
        unimplemented = self.create_error_response("unimplemented endpoint")
        return unimplemented

    async def handle(
        self,
        ctx: ServeContext,
440
    ) -> AnyResponse | ErrorResponse:
441
        async for response in self._pipeline(ctx):
442
443
444
445
446
447
448
            return response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
449
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
        """Execute the request processing pipeline yielding responses."""
        if error := await self._check_model(ctx.request):
            yield error
        if error := self._validate_request(ctx):
            yield error

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret

        yield self._build_response(ctx)

470
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
471
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
472

473
474
        if (
            truncate_prompt_tokens is not None
475
            and truncate_prompt_tokens > self.model_config.max_model_len
476
        ):
477
478
479
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
480
481
                " Please, select a smaller truncation size."
            )
482
483
        return None

484
485
486
    def _create_pooling_params(
        self,
        ctx: ServeContext,
487
    ) -> PoolingParams | ErrorResponse:
488
489
        if not hasattr(ctx.request, "to_pooling_params"):
            return self.create_error_response(
490
491
                "Request type does not support pooling parameters"
            )
492
493
494

        return ctx.request.to_pooling_params()

495
496
497
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Submit each engine prompt for encoding and store the merged stream.

        One ``engine_client.encode`` call is made per entry of
        ``ctx.engine_prompts``, with request ID ``f"{ctx.request_id}-{index}"``.
        Their outputs are merged into ``ctx.result_generator``.

        Returns:
            ``None`` on success. Returns an ``ErrorResponse`` if pooling params
            cannot be created or ``ctx.engine_prompts`` is unset.
        """
        if ctx.raw_request is None:
            trace_headers = None
        else:
            trace_headers = await self._get_trace_headers(ctx.raw_request.headers)

        pooling_params = self._create_pooling_params(ctx)
        if isinstance(pooling_params, ErrorResponse):
            return pooling_params

        engine_prompts = ctx.engine_prompts
        if engine_prompts is None:
            return self.create_error_response("Engine prompts not available")

        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
        for index, engine_prompt in enumerate(engine_prompts):
            item_id = f"{ctx.request_id}-{index}"
            self._log_inputs(
                item_id,
                engine_prompt,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )
            generators.append(
                self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    item_id,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=getattr(ctx.request, "priority", 0),
                )
            )

        ctx.result_generator = merge_async_iterators(*generators)
        return None
539
540
541
542

    async def _collect_batch(
        self,
        ctx: ServeContext,
543
    ) -> ErrorResponse | None:
544
        """Collect batch results from the result generator."""
545
546
        if ctx.engine_prompts is None:
            return self.create_error_response("Engine prompts not available")
547

548
549
550
        num_prompts = len(ctx.engine_prompts)
        final_res_batch: list[PoolingRequestOutput | None]
        final_res_batch = [None] * num_prompts
551

552
553
        if ctx.result_generator is None:
            return self.create_error_response("Result generator not available")
554

555
556
        async for i, res in ctx.result_generator:
            final_res_batch[i] = res
557

558
559
560
561
        if None in final_res_batch:
            return self.create_error_response(
                "Failed to generate results for all prompts"
            )
562

563
        ctx.final_res_batch = [res for res in final_res_batch if res is not None]
564

565
        return None
566

567
    @staticmethod
    def create_error_response(
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an ``ErrorResponse``; defaults describe a 400 BadRequestError.

        This is a static wrapper around the module-level
        ``create_error_response`` helper. All arguments are forwarded
        positionally in declaration order.
        """
        forwarded = (message, err_type, status_code, param)
        return create_error_response(*forwarded)
575

576
    def create_streaming_error_response(
577
        self,
578
        message: str | Exception,
579
580
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
581
        param: str | None = None,
582
    ) -> str:
583
        json_str = json.dumps(
584
            self.create_error_response(
585
586
587
588
                message=message,
                err_type=err_type,
                status_code=status_code,
                param=param,
589
590
            ).model_dump()
        )
591
592
        return json_str

593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Raise GenerationError if finish_reason indicates an error."""
        if finish_reason == "error":
            logger.error(
                "Request %s failed with an internal error during generation",
                request_id,
            )
            raise GenerationError("Internal server error")

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Convert GenerationError to streaming error response."""
        return self.create_streaming_error_response(
            str(e),
            err_type="InternalServerError",
            status_code=e.status_code,
        )

612
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Return ``None`` if ``request.model`` can be served, else an error.

        The model is accepted in three cases:

        * it is a supported base model;
        * it is an already-registered LoRA adapter;
        * ``VLLM_ALLOW_RUNTIME_LORA_UPDATING`` is enabled and
          ``models.resolve_lora`` returns a ``LoRARequest``.

        A 400-coded ``ErrorResponse`` from ``resolve_lora`` is passed through
        unchanged. Every other failure becomes a 404 ``NotFoundError`` for
        the ``model`` parameter.
        """
        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None

        resolve_error: ErrorResponse | None = None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model:
            load_result = await self.models.resolve_lora(request.model)
            if load_result:
                if isinstance(load_result, LoRARequest):
                    return None
                is_bad_request = (
                    isinstance(load_result, ErrorResponse)
                    and load_result.error.code == HTTPStatus.BAD_REQUEST.value
                )
                if is_bad_request:
                    resolve_error = load_result

        if resolve_error:
            return resolve_error
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )
641

642
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

665
    def _maybe_get_adapters(
666
667
668
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
669
    ) -> LoRARequest | None:
670
        if request.model in self.models.lora_requests:
671
            return self.models.lora_requests[request.model]
672
673
674
675
676
677

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
678
                return default_mm_lora
679
680

        if self._is_model_supported(request.model):
681
            return None
682

683
        # if _check_model has been called earlier, this will be unreachable
684
        raise ValueError(f"The model `{request.model}` does not exist.")
685

686
687
688
689
690
691
692
693
694
695
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

696
697
698
699
700
        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
701
702
703
704
705
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
706
707
708
709
710
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

711
712
    def _validate_input(
        self,
        request: object,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Check a tokenized prompt against the model's context length.

        The rules depend on the request type:

        * Score and rerank requests may use the whole context (``<=``).
        * Tokenize and detokenize requests are not length-checked.
        * All other requests must leave room for at least one token, and
          ``prompt + max_tokens`` must fit in ``max_model_len``.

        Args:
            request: The API request the prompt belongs to.
            input_ids: Token IDs of the prompt.
            input_text: Prompt text, copied into the returned prompt.

        Returns:
            ``TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)``.

        Raises:
            VLLMValidationError: If a length limit is exceeded.
        """
        token_num = len(input_ids)
        max_model_len = self.model_config.max_model_len

        # Note: score and rerank requests don't have max_tokens
        if isinstance(
            request,
            (
                ScoreDataRequest,
                ScoreTextRequest,
                ScoreQueriesDocumentsRequest,
                RerankRequest,
            ),
        ):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num > max_model_len:
                operations: dict[type[AnyRequest], str] = {
                    ScoreDataRequest: "score",
                    ScoreTextRequest: "score",
                    ScoreQueriesDocumentsRequest: "score",
                    # Without this entry rerank errors were mislabelled as
                    # "embedding generation".
                    RerankRequest: "rerank",
                }
                operation = operations.get(type(request), "embedding generation")
                raise VLLMValidationError(
                    f"This model's maximum context length is "
                    f"{max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.",
                    parameter="input_tokens",
                    value=token_num,
                )
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        if isinstance(
            request,
            (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
        ):
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= max_model_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages.",
                parameter="input_tokens",
                value=token_num,
            )

        if max_tokens is not None and token_num + max_tokens > max_model_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{max_model_len} tokens. However, you requested "
                f"{max_tokens} output tokens and your prompt contains "
                f"{token_num} input tokens, for a total of "
                f"{token_num + max_tokens} tokens "
                f"({token_num} + {max_tokens} = "
                f"{token_num + max_tokens} > {max_model_len}). "
                f"Please reduce the length of the input prompt or the "
                f"number of requested output tokens.",
                parameter="max_tokens",
                value=max_tokens,
            )

        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
792

793
794
    def _validate_chat_template(
        self,
795
796
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
797
        trust_request_chat_template: bool,
798
    ) -> ErrorResponse | None:
799
        if not trust_request_chat_template and (
800
801
802
803
804
805
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
806
807
808
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
809
810
                "Refused request with untrusted chat template."
            )
811
812
        return None

813
814
815
816
817
818
819
820
821
822
823
824
    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs."""
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
        return default_chat_template_kwargs | request_chat_template_kwargs

825
826
827
828
829
    async def _preprocess_completion(
        self,
        request: RendererRequest,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
    ) -> list[ProcessorInputs]:
        """Render a completion-style request into engine inputs.

        Both prompt sources are normalized with ``prompt_to_seq`` and combined
        into one batch, which is then handed to :meth:`_preprocess_cmpl`.

        Args:
            request: The completion-style request being served.
            prompt_input: Text or token-id prompt(s), or ``None``.
            prompt_embeds: Serialized prompt embedding(s), or ``None``.
        """
        prompts: list[SingletonPrompt | bytes] = []
        # Embeds take higher priority, so they are queued ahead of text/tokens.
        if prompt_embeds is not None:
            prompts += prompt_to_seq(prompt_embeds)
        if prompt_input is not None:
            prompts += prompt_to_seq(prompt_input)

        return await self._preprocess_cmpl(request, prompts)

    async def _preprocess_cmpl(
        self,
        request: RendererRequest,
        prompts: Sequence[PromptType | bytes],
843
    ) -> list[ProcessorInputs]:
844
845
846
        renderer = self.renderer
        model_config = self.model_config

847
848
849
850
851
852
853
854
        parsed_prompts = [
            (
                prompt
                if isinstance(prompt, bytes)
                else parse_model_prompt(model_config, prompt)
            )
            for prompt in prompts
        ]
855
        tok_params = request.build_tok_params(model_config)
856

857
858
859
860
861
862
863
864
865
        return await renderer.render_cmpl_async(
            parsed_prompts,
            tok_params,
            prompt_extras={
                k: v
                for k in ("mm_processor_kwargs", "cache_salt")
                if (v := getattr(request, k, None)) is not None
            },
        )
866

867
868
    async def _preprocess_chat(
        self,
        request: RendererChatRequest,
        messages: list[ChatCompletionMessageParam],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
    ) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
        """Render a single chat conversation into engine inputs.

        Args:
            request: Chat-style request that supplies tokenization and chat
                parameters (``build_tok_params`` / ``build_chat_params``).
            messages: The conversation to render.
            default_template: Server default chat template.
            default_template_content_format: Server default content format.
            default_template_kwargs: Base chat-template kwargs. ``tools`` and
                ``tokenize`` are combined into them via ``merge_kwargs``.
            tool_dicts: Tool definitions exposed to the template as ``tools``.
            tool_parser: Optional tool-parser factory. When set and the request's
                ``tool_choice`` is not ``"none"``, the parser's ``adjust_request``
                is applied to the request.

        Returns:
            The rendered conversation and a one-element list with its engine
            prompt.

        Raises:
            NotImplementedError: A tool parser is set and ``tool_choice`` is not
                ``"none"``, but the request is neither a Chat Completions nor a
                Responses request.
        """
        renderer = self.renderer
        model_config = self.model_config

        template_kwargs = merge_kwargs(
            default_template_kwargs,
            {
                "tools": tool_dicts,
                "tokenize": is_mistral_tokenizer(renderer.tokenizer),
            },
        )

        mm_config = model_config.multimodal_config
        media_io_kwargs = mm_config.media_io_kwargs if mm_config else None

        tok_params = request.build_tok_params(model_config)
        chat_params = request.build_chat_params(
            default_template, default_template_content_format
        ).with_defaults(template_kwargs, default_media_io_kwargs=media_io_kwargs)

        prompt_extras = {
            key: value
            for key in ("mm_processor_kwargs", "cache_salt")
            if (value := getattr(request, key, None)) is not None
        }
        conversations, engine_prompts = await renderer.render_chat_async(
            [messages], chat_params, tok_params, prompt_extras=prompt_extras
        )
        # A single message list was submitted, so exactly one conversation and
        # one prompt are expected back (the unpacking raises otherwise).
        (conversation,) = conversations
        (engine_prompt,) = engine_prompts

        # Tool parsing is done only if a tool_parser has been set and
        # tool_choice is not "none" (with tool_choice="none" and a tool_parser
        # set, we want to avoid parsing a tool call hallucinated by the LLM).
        if tool_parser is not None and (
            getattr(request, "tool_choice", "none") != "none"
        ):
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                raise NotImplementedError(
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )

            # TODO: Update adjust_request to accept ResponsesRequest
            tokenizer = renderer.get_tokenizer()
            # The return value is not used below, so only in-place changes to
            # `request` have an effect here.
            # NOTE(review): confirm every parser adjusts the request in place
            # rather than returning a modified copy.
            tool_parser(tokenizer).adjust_request(request=request)  # type: ignore[arg-type]

        return conversation, [engine_prompt]
926

927
    def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs):
        """Break *prompt* into its components (``text``, ``token_ids``, ``embeds``)
        using this server's model config."""
        model_config = self.model_config
        return extract_prompt_components(model_config, prompt)

930
    def _extract_prompt_text(self, prompt: ProcessorInputs):
931
932
        return self._extract_prompt_components(prompt).text

933
    def _extract_prompt_len(self, prompt: ProcessorInputs):
        """Return the length of *prompt* as reported by ``extract_prompt_len``.

        This value is used as the prompt length when computing ``max_tokens``
        for follow-up turns.
        """
        model_config = self.model_config
        return extract_prompt_len(model_config, prompt)

936
937
938
939
940
    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        """Render the prompt for the next turn of a Responses request.

        *messages* are converted with ``construct_input_messages`` and rendered by
        :meth:`_preprocess_chat` using the given template. No server default
        template kwargs are applied (``default_template_kwargs=None``).

        Returns:
            The engine prompts produced by :meth:`_preprocess_chat`. The rendered
            conversation it also returns is discarded.
        """
        chat_messages = construct_input_messages(request_input=messages)

        rendered = await self._preprocess_chat(
            request,
            chat_messages,
            default_template=chat_template,
            default_template_content_format=chat_template_content_format,
            default_template_kwargs=None,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
        )
        _conversation, engine_prompts = rendered
        return engine_prompts
959

960
961
962
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
963
        engine_prompt: ProcessorInputs,
964
965
        sampling_params: SamplingParams,
        context: ConversationContext,
966
        lora_request: LoRARequest | None = None,
967
        priority: int = 0,
968
        trace_headers: Mapping[str, str] | None = None,
969
    ):
970
        max_model_len = self.model_config.max_model_len
971

972
        orig_priority = priority
973
        sub_request = 0
974
        while True:
975
976
            # Ensure that each sub-request has a unique request id.
            sub_request_id = f"{request_id}_{sub_request}"
977

978
            self._log_inputs(
979
                sub_request_id,
980
                engine_prompt,
981
982
983
                params=sampling_params,
                lora_request=lora_request,
            )
984

985
            generator = self.engine_client.generate(
986
                engine_prompt,
987
                sampling_params,
988
                sub_request_id,
989
                lora_request=lora_request,
990
                trace_headers=trace_headers,
991
992
                priority=priority,
            )
993

994
995
996
997
998
999
1000
1001
1002
1003
1004
            async for res in generator:
                context.append_output(res)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                break

            # Call the tool and update the context with the result.
            tool_output = await context.call_tool()
1005
            context.append_tool_output(tool_output)
1006
1007
1008
1009
1010

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn.
1011
            # Render the next prompt token ids and update sampling_params.
1012
            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
1013
                token_ids = context.render_for_completion()
1014
                engine_prompt = token_inputs(token_ids)
1015

1016
                sampling_params.max_tokens = max_model_len - len(token_ids)
1017
            elif isinstance(context, ParsableContext):
1018
                (engine_prompt,) = await self._render_next_turn(
1019
1020
1021
1022
1023
1024
1025
                    context.request,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )
1026
1027

                sampling_params.max_tokens = get_max_tokens(
1028
                    max_model_len,
1029
                    context.request.max_output_tokens,
1030
                    self._extract_prompt_len(engine_prompt),
1031
                    self.default_sampling_params,  # type: ignore
1032
                    self.override_max_tokens,  # type: ignore
1033
                )
1034

1035
1036
            # OPTIMIZATION
            priority = orig_priority - 1
1037
            sub_request += 1
1038

1039
1040
1041
    def _log_inputs(
        self,
        request_id: str,
1042
        inputs: PromptType | ProcessorInputs,
1043
1044
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
1045
1046
1047
    ) -> None:
        if self.request_logger is None:
            return
1048

1049
        components = self._extract_prompt_components(inputs)
1050
1051
1052

        self.request_logger.log_inputs(
            request_id,
1053
1054
1055
            components.text,
            components.token_ids,
            components.embeds,
1056
1057
1058
            params=params,
            lora_request=lora_request,
        )
1059

1060
1061
1062
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return the trace headers extracted from *headers* when tracing is on.

        When tracing is disabled, ``None`` is returned. If the request
        nevertheless carries trace headers, ``log_tracing_disabled_warning`` is
        called first.
        """
        if not await self.engine_client.is_tracing_enabled():
            if contains_trace_headers(headers):
                log_tracing_disabled_warning()
            return None

        return extract_trace_headers(headers)

1074
    @staticmethod
1075
    def _base_request_id(
1076
1077
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
1078
        """Pulls the request id to use from a header, if provided"""
1079
1080
1081
1082
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id
1083

1084
        return random_uuid() if default is None else default
1085

1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided"""
        if raw_request is None:
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None

        try:
            return int(rank_str)
        except ValueError:
            return None

1101
1102
1103
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Turn generated *content* into function calls per ``request.tool_choice``.

        Behavior by ``tool_choice``:

        * Named tool (``ToolChoiceFunction`` or
          ``ChatCompletionNamedToolChoiceParam``): the whole *content* becomes
          the arguments of a single call to that function.
        * ``"required"``: *content* is validated as a JSON list of
          ``FunctionDefinition`` objects, producing one call per entry.
        * ``"auto"`` or ``None``, with auto tools enabled and a parser class: the
          tool parser extracts calls from *content*.
        * Anything else (e.g. ``"none"``): no parsing. *content* is returned
          unchanged with an empty call list.

        The named and ``"required"`` paths assert that *content* is not ``None``.

        Returns:
            ``(function_calls, content)``.

            * ``function_calls`` is ``None`` only when the auto parser ran and
              reported no tool calls. Otherwise it is a (possibly empty) list.
            * ``content`` is ``None`` when the output was consumed by a
              named/required call, or when the text left after auto parsing is
              empty or whitespace-only.

        Raises:
            ValueError: Auto parsing is needed but *tokenizer* is ``None``.
        """
        tool_choice = request.tool_choice
        function_calls: list[FunctionCall] = []

        if isinstance(
            tool_choice, (ToolChoiceFunction, ChatCompletionNamedToolChoiceParam)
        ):
            # Forced function call: the entire output is the call's arguments.
            assert content is not None
            forced_name = (
                tool_choice.name
                if isinstance(tool_choice, ToolChoiceFunction)
                else tool_choice.function.name
            )
            function_calls.append(FunctionCall(name=forced_name, arguments=content))
            return function_calls, None  # content was consumed by the tool call

        if tool_choice == "required":
            assert content is not None
            tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
            function_calls.extend(
                FunctionCall(
                    name=tool_call.name,
                    arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
                )
                for tool_call in tool_calls
            )
            return function_calls, None  # content was consumed by the tool calls

        if not (
            tool_parser_cls
            and enable_auto_tools
            and (tool_choice == "auto" or tool_choice is None)
        ):
            # No forced call and no automatic parsing applies.
            return function_calls, content

        if tokenizer is None:
            raise ValueError(
                "Tokenizer not available when `skip_tokenizer_init=True`"
            )

        # Automatic Tool Call Parsing
        try:
            tool_parser = tool_parser_cls(tokenizer)
        except RuntimeError:
            logger.exception("Error in tool parser creation.")
            raise  # bare raise keeps the original traceback untouched
        tool_call_info = tool_parser.extract_tool_calls(
            content or "",
            request=request,  # type: ignore
        )
        if tool_call_info is None or not tool_call_info.tools_called:
            # No tool calls.
            return None, content

        function_calls.extend(
            FunctionCall(
                id=tool_call.id,
                name=tool_call.function.name,
                arguments=tool_call.function.arguments,
            )
            for tool_call in tool_call_info.tool_calls
        )
        remaining = tool_call_info.content
        # Blank leftover text (empty or whitespace-only) is reported as no content.
        if remaining is not None and not remaining.strip():
            remaining = None
        return function_calls, remaining

1178
    @staticmethod
1179
1180
1181
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
1182
        tokenizer: TokenizerLike | None,
1183
1184
        return_as_token_id: bool = False,
    ) -> str:
1185
1186
1187
        if return_as_token_id:
            return f"token_id:{token_id}"

1188
1189
        if logprob.decoded_token is not None:
            return logprob.decoded_token
1190
1191
1192
1193
1194
1195

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

1196
        return tokenizer.decode([token_id])
1197

1198
    def _is_model_supported(self, model_name: str | None) -> bool:
1199
1200
        if not model_name:
            return True
1201
        return self.models.is_base_model(model_name)
1202

1203
1204

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    ``None`` positions in *prompt_logprobs* are skipped. The same (mutated)
    object is returned, or ``None`` when the input is ``None``.
    """
    if prompt_logprobs is None:
        return None

    negative_infinity = float("-inf")
    # NOTE(review): presumably clamped so the values survive strict JSON
    # serialization, which has no representation for -inf. Confirm.
    entries = (
        entry
        for position in prompt_logprobs
        if position is not None
        for entry in position.values()
    )
    for entry in entries:
        if entry.logprob == negative_infinity:
            entry.logprob = -9999.0
    return prompt_logprobs