serving.py 43.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import contextlib
5
import json
6
import time
7
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
8
from dataclasses import dataclass, field
9
from http import HTTPStatus
10
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
11

12
import numpy as np
13
from fastapi import Request
14
15
16
from openai.types.responses import (
    ToolChoiceFunction,
)
17
from pydantic import ConfigDict, TypeAdapter, ValidationError
18
from starlette.datastructures import Headers
19

20
import vllm.envs as envs
21
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
22
from vllm.config import ModelConfig
23
from vllm.engine.protocol import EngineClient
24
25
26
27
28
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
29
from vllm.entrypoints.logger import RequestLogger
30
from vllm.entrypoints.openai.chat_completion.protocol import (
31
    ChatCompletionNamedToolChoiceParam,
32
33
    ChatCompletionRequest,
    ChatCompletionResponse,
34
)
35
from vllm.entrypoints.openai.completion.protocol import (
36
37
    CompletionRequest,
    CompletionResponse,
38
39
)
from vllm.entrypoints.openai.engine.protocol import (
40
    ErrorResponse,
41
    FunctionCall,
42
    FunctionDefinition,
43
    GenerationError,
44
)
45
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
46
47
48
49
50
51
from vllm.entrypoints.openai.responses.context import (
    ConversationContext,
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
52
53
54
55
from vllm.entrypoints.openai.responses.protocol import (
    ResponseInputOutputItem,
    ResponsesRequest,
)
56
57
58
from vllm.entrypoints.openai.responses.utils import (
    construct_input_messages,
)
59
from vllm.entrypoints.openai.speech_to_text.protocol import (
60
61
62
63
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
64
65
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
66
67
    PoolingChatRequest,
    PoolingCompletionRequest,
68
69
70
71
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
72
73
    ScoreDataRequest,
    ScoreQueriesDocumentsRequest,
74
75
    ScoreRequest,
    ScoreResponse,
76
    ScoreTextRequest,
77
)
78
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
79
80
81
82
83
84
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
85
from vllm.entrypoints.utils import create_error_response, get_max_tokens
86
from vllm.exceptions import VLLMValidationError
87
88
89
90
91
92
93
from vllm.inputs.data import (
    ProcessorInputs,
    PromptType,
    SingletonPrompt,
    TokensPrompt,
    token_inputs,
)
94
from vllm.logger import init_logger
95
from vllm.logprobs import Logprob, PromptLogprobs
96
from vllm.lora.request import LoRARequest
97
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
98
from vllm.pooling_params import PoolingParams
99
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
100
101
102
103
104
105
from vllm.renderers.inputs.preprocess import (
    extract_prompt_components,
    extract_prompt_len,
    parse_model_prompt,
    prompt_to_seq,
)
106
from vllm.sampling_params import BeamSearchParams, SamplingParams
107
from vllm.tokenizers import TokenizerLike
108
from vllm.tool_parsers import ToolParser
109
110
111
112
113
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
114
from vllm.utils import random_uuid
115
from vllm.utils.async_utils import (
116
    collect_from_async_generator,
117
118
    merge_async_iterators,
)
119
from vllm.utils.mistral import is_mistral_tokenizer
120
121
122

logger = init_logger(__name__)

123
124
125
126
127
128
129
130
131
132
133
134
135
136
137

class RendererRequest(Protocol):
    """Structural type for requests that can supply tokenization parameters."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Return the `TokenizeParams` for this request under `model_config`.

        The protocol itself provides no behavior; this stub always raises
        `NotImplementedError`.
        """
        raise NotImplementedError


class RendererChatRequest(RendererRequest, Protocol):
    """`RendererRequest` that can additionally supply chat-template parameters."""

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        """Return the `ChatParams` for this request.

        Args:
            default_template: Template passed in by the caller as the default;
                may be `None`.
            default_template_content_format: Content-format option passed in by
                the caller as the default.

        The protocol itself provides no behavior; this stub always raises
        `NotImplementedError`.
        """
        raise NotImplementedError


138
139
# Request types grouped under the completion-style alias.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | TokenizeCompletionRequest
    | DetokenizeRequest
    | RerankRequest
    | ScoreRequest
    | PoolingCompletionRequest
)

# Request types grouped under the chat-style alias.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest | TokenizeChatRequest | PoolingChatRequest
)

# Transcription and translation requests.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Every request type accepted by the serving layer in this module.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Every non-error response type produced by the serving layer in this module.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ScoreResponse
    | GenerateResponse
)

# Type variable for code that is generic over the concrete request type.
RequestT = TypeVar("RequestT", bound=AnyRequest)


175
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
    """Keyword-only bundle of state carried through one request's pipeline.

    `request`, `model_name` and `request_id` are required; every other field
    has a default and is filled in by later pipeline stages.
    """

    # The API request being served and, if available, the raw HTTP request.
    request: RequestT
    raw_request: Request | None = None

    # Identification of the request.
    model_name: str
    request_id: str
    # Creation timestamp in whole seconds (from `time.time()`).
    created_time: int = field(default_factory=lambda: int(time.time()))

    # Inputs prepared for the engine.
    lora_request: LoRARequest | None = None
    engine_prompts: list[ProcessorInputs] | None = None

    # Stream of `(prompt index, output)` pairs and the collected outputs.
    result_generator: (
        AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None
    ) = None
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # Unannotated, so this is a plain class attribute, not a dataclass field.
    model_config = ConfigDict(arbitrary_types_allowed=True)
191
192


193
class OpenAIServing:
194
    # NOTE(review): the value assigned here is descriptive text rather than a
    # usable prefix; presumably subclasses override it — confirm.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID.
    """
197

198
199
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
    ):
        """Store the engine client, model registry and serving options.

        Args:
            engine_client: Client used to submit work to the engine; its
                `model_config`, `renderer`, `io_processor` and
                `input_processor` are also cached on this instance.
            models: Model registry consulted when resolving request models.
            request_logger: Optional request logger (keyword-only).
            return_tokens_as_token_ids: Option flag stored as-is on the
                instance (keyword-only).
        """
        super().__init__()

        # Collaborators and options supplied by the caller.
        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        # Shortcuts to components owned by the engine client.
        self.model_config = engine_client.model_config
        self.renderer = engine_client.renderer
        self.io_processor = engine_client.io_processor
        self.input_processor = engine_client.input_processor
219
220
221

    async def beam_search(
        self,
        prompt: ProcessorInputs,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search by repeatedly asking the engine for a single token.

        Every step issues one single-token generation per live beam (asking
        for `2 * beam_width` logprobs), pools the candidate continuations of
        all beams, retires candidates that end in EOS (unless `ignore_eos`)
        and keeps the `beam_width` best remaining candidates as the next beams.

        Args:
            prompt: Processed prompt; embeddings prompts are rejected.
            request_id: Base ID; per-step / per-beam IDs are derived from it.
            params: Beam search settings (width, max tokens, penalties, ...).
            lora_request: LoRA request attached to the initial beam.
            trace_headers: Tracing headers forwarded to the engine.

        Yields:
            Exactly one `RequestOutput`: either the best `beam_width` beams,
            or a single output with `finish_reason="error"` if any engine
            result reported an error.

        Raises:
            NotImplementedError: If `prompt` is an embeddings prompt.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        if prompt["type"] == "embeds":
            raise NotImplementedError("Embedding prompt not supported for beam search")

        # Extract prompt tokens and text based on model type
        decoder_prompt = (
            prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
        )
        prompt_text = decoder_prompt.get("prompt")
        prompt_token_ids = decoder_prompt["prompt_token_ids"]

        tokenized_length = len(prompt_token_ids)

        logprobs_num = 2 * beam_width
        sampling_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            if not all_beams:
                # No live beam is left to extend; further steps are no-ops.
                break

            request_id_batch = f"{request_id}-{random_uuid()}"
            tasks = [
                asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            beam.get_prompt(),
                            sampling_params,
                            f"{request_id_batch}-beam-{i}",
                            lora_request=beam.lora_request,
                            trace_headers=trace_headers,
                        )
                    )
                )
                for i, beam in enumerate(all_beams)
            ]

            output = [x[0] for x in await asyncio.gather(*tasks)]

            # Flat candidate pool across all beams. `candidate_owner[k]` is
            # the index of the beam (and engine result) that candidate `k`
            # came from; tracking it explicitly avoids assuming that every
            # result contributes exactly `logprobs_num` entries.
            candidate_token_ids: list[int] = []
            candidate_logprobs: list[float] = []
            candidate_owner: list[int] = []

            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    candidate_token_ids.extend(logprobs.keys())
                    candidate_logprobs.extend(
                        current_beam.cum_logprob + obj.logprob
                        for obj in logprobs.values()
                    )
                    candidate_owner.extend([i] * len(logprobs))

            token_arr = np.array(candidate_token_ids)
            logprob_arr = np.array(candidate_logprobs, dtype=float)

            # Handle the token for the end of sentence (EOS)
            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(token_arr == eos_token_id)[0]
                for idx in eos_idx:
                    owner = candidate_owner[idx]
                    current_beam = all_beams[owner]
                    result = output[owner]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            orig_prompt=prompt,
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(logprob_arr[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # After processing, set the log probability of the eos condition
                # to negative infinity.
                logprob_arr[eos_idx] = -np.inf

            # Processing non-EOS tokens
            # Get indices of the top beam_width probabilities
            if len(logprob_arr) > beam_width:
                topn_idx = np.argpartition(np.negative(logprob_arr), beam_width)[
                    :beam_width
                ]
            else:
                # argpartition requires kth < size; with this few candidates
                # keep every one that was not retired (marked -inf) above.
                topn_idx = np.flatnonzero(logprob_arr != -np.inf)

            new_beams = []
            for idx in topn_idx:
                owner = candidate_owner[idx]
                current_beam = all_beams[owner]
                result = output[owner]
                token_id = int(token_arr[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        orig_prompt=prompt,
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(logprob_arr[idx]),
                    )
                )

            all_beams = new_beams

        # Beams still live when the token budget runs out compete with the
        # completed ones for the final top `beam_width`.
        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
418

419
420
421
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Preprocessing hook; the base implementation does nothing.

        Subclasses may override this to prepare `ctx`. Returning an
        `ErrorResponse` signals failure; `None` means success.
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Build the final response; the base implementation reports an error.

        Subclasses may override this to return the response object that
        matches their endpoint.
        """
        unimplemented = self.create_error_response("unimplemented endpoint")
        return unimplemented

    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Run the pipeline for `ctx` and return the first response it yields.

        If the pipeline finishes without yielding anything, an error response
        is returned instead.
        """
        pipeline = self._pipeline(ctx)
        async for first_response in pipeline:
            # Only the first yielded item is ever consumed.
            return first_response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
        """Execute the request processing pipeline yielding responses.

        Stages run in order: model check, request validation, preprocessing,
        generator preparation, batch collection, response building. The first
        stage that produces an `ErrorResponse` yields it and ends the
        pipeline, so later stages never run on a request that already failed,
        even if the consumer keeps iterating after the first item.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

471
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
        """Reject requests whose `truncate_prompt_tokens` exceeds `max_model_len`.

        Requests without a `truncate_prompt_tokens` attribute, or with it set
        to `None`, always pass. Returns `None` on success.
        """
        requested = getattr(ctx.request, "truncate_prompt_tokens", None)
        too_large = (
            requested is not None and requested > self.model_config.max_model_len
        )
        if not too_large:
            return None

        return self.create_error_response(
            "truncate_prompt_tokens value is "
            "greater than max_model_len."
            " Please, select a smaller truncation size."
        )

485
486
487
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> PoolingParams | ErrorResponse:
        """Return the request's pooling parameters.

        Delegates to `ctx.request.to_pooling_params()`; if the request has no
        such attribute, an error response is returned instead.
        """
        if hasattr(ctx.request, "to_pooling_params"):
            return ctx.request.to_pooling_params()

        return self.create_error_response(
            "Request type does not support pooling parameters"
        )

496
497
498
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Submit one encode request per engine prompt and merge the streams.

        On success the merged stream is stored in `ctx.result_generator` and
        `None` is returned. An `ErrorResponse` is returned if the pooling
        parameters cannot be built or `ctx.engine_prompts` is unset.
        """
        raw_request = ctx.raw_request
        trace_headers = None
        if raw_request is not None:
            trace_headers = await self._get_trace_headers(raw_request.headers)

        pooling_params = self._create_pooling_params(ctx)
        if isinstance(pooling_params, ErrorResponse):
            return pooling_params

        engine_prompts = ctx.engine_prompts
        if engine_prompts is None:
            return self.create_error_response("Engine prompts not available")

        # Requests without a `priority` attribute fall back to 0.
        priority = getattr(ctx.request, "priority", 0)

        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
        for index, engine_prompt in enumerate(engine_prompts):
            # Each prompt gets its own ID derived from the request ID.
            sub_request_id = f"{ctx.request_id}-{index}"

            self._log_inputs(
                sub_request_id,
                engine_prompt,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            generators.append(
                self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    sub_request_id,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=priority,
                )
            )

        ctx.result_generator = merge_async_iterators(*generators)
        return None
540
541
542
543

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Drain `ctx.result_generator` into `ctx.final_res_batch`.

        Results are stored at the index they are tagged with, one slot per
        engine prompt. Returns an `ErrorResponse` if the prompts or the
        generator are missing, or if any slot is still empty after the
        generator is exhausted; otherwise `None`.
        """
        if ctx.engine_prompts is None:
            return self.create_error_response("Engine prompts not available")

        # One slot per prompt; later results for an index overwrite earlier ones.
        slots: list[PoolingRequestOutput | None] = [None] * len(ctx.engine_prompts)

        result_stream = ctx.result_generator
        if result_stream is None:
            return self.create_error_response("Result generator not available")

        async for prompt_index, result in result_stream:
            slots[prompt_index] = result

        filled = [result for result in slots if result is not None]
        if len(filled) != len(slots):
            return self.create_error_response(
                "Failed to generate results for all prompts"
            )

        ctx.final_res_batch = filled
        return None
567

568
    @staticmethod
    def create_error_response(
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an `ErrorResponse`; defaults describe a 400 Bad Request.

        Thin wrapper: inside this body the bare name `create_error_response`
        resolves to the module-level function of the same name, which
        receives the arguments positionally.
        """
        response = create_error_response(message, err_type, status_code, param)
        return response
576

577
    def create_streaming_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> str:
        """Return the error response for these arguments as a JSON string.

        The response is built with `create_error_response`, dumped with
        `model_dump()` and serialized with `json.dumps`.
        """
        error = self.create_error_response(
            message=message,
            err_type=err_type,
            status_code=status_code,
            param=param,
        )
        return json.dumps(error.model_dump())

594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Raise GenerationError if finish_reason indicates an error."""
        if finish_reason == "error":
            logger.error(
                "Request %s failed with an internal error during generation",
                request_id,
            )
            raise GenerationError("Internal server error")

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Serialize a `GenerationError` as a streaming error payload.

        The error's string form becomes the message and its `status_code` is
        forwarded; the error type is always "InternalServerError".
        """
        message = str(e)
        return self.create_streaming_error_response(
            message,
            err_type="InternalServerError",
            status_code=e.status_code,
        )

613
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Check that `request.model` names a model this server can serve.

        Returns `None` when the model is supported, is an already-registered
        LoRA, or (with runtime LoRA updating enabled) resolves to a
        `LoRARequest`. If resolution yields a 400 `ErrorResponse`, that
        response is returned; in every other case a 404 "NotFoundError"
        response is returned.
        """
        model = request.model

        if self._is_model_supported(model):
            return None
        if model in self.models.lora_requests:
            return None

        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and model:
            load_result = await self.models.resolve_lora(model)
            # A falsy result is treated the same as "not resolved".
            if load_result:
                if isinstance(load_result, LoRARequest):
                    return None
                is_bad_request = (
                    isinstance(load_result, ErrorResponse)
                    and load_result.error.code == HTTPStatus.BAD_REQUEST.value
                )
                if is_bad_request:
                    # Surface the loader's own 400 instead of a generic 404.
                    return load_result

        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )
642

643
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
        """Return the single registered LoRA matching the request's content types.

        A LoRA matches when its `lora_name` equals one of the type prefixes
        returned by `_get_message_types` (e.g. `audio_url` -> `audio`).
        Returns `None` unless exactly one LoRA matches.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)

        # Best effort match for default multimodal lora adapters;
        # There is probably a better way to do this, but currently
        # this matches against the set of 'types' in any content lists
        # up until '_', e.g., to match audio_url -> audio
        matched = {
            lora
            for lora in self.models.lora_requests.values()
            if lora.lora_name in message_types
        }

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        return matched.pop() if len(matched) == 1 else None

666
    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> LoRARequest | None:
        """Pick the LoRA adapter, if any, to use for `request`.

        Lookup order: a registered LoRA named by `request.model`; then, when
        `supports_default_mm_loras` is set, the single default multimodal
        LoRA matching the request; then `None` for a supported base model.

        Raises:
            ValueError: If `request.model` matches none of the above.
        """
        registered = self.models.lora_requests
        if request.model in registered:
            return registered[request.model]

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if (
            supports_default_mm_loras
            and (mm_lora := self._get_active_default_mm_loras(request)) is not None
        ):
            return mm_lora

        if not self._is_model_supported(request.model):
            # if _check_model has been called earlier, this will be unreachable
            raise ValueError(f"The model `{request.model}` does not exist.")

        return None
686

687
688
689
690
691
692
693
694
695
696
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

697
698
699
700
701
        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
702
703
704
705
706
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
707
708
709
710
711
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

712
713
    def _validate_input(
        self,
        request: object,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Check the prompt length against `max_model_len` and wrap the tokens.

        Three cases, by request type:
          * score / rerank requests: the input may use the whole context;
          * tokenize / detokenize requests: no length validation;
          * everything else: the input must be strictly shorter than the
            context, and input plus requested output tokens must fit in it.

        Returns:
            A `TokensPrompt` holding `input_text` and `input_ids`.

        Raises:
            VLLMValidationError: If a length limit is exceeded.
        """
        token_num = len(input_ids)
        max_model_len = self.model_config.max_model_len
        # Built up front; only returned once the applicable checks pass.
        tokens_prompt = TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Note: ScoreRequest doesn't have max_tokens
        no_output_requests = (
            ScoreDataRequest,
            ScoreTextRequest,
            ScoreQueriesDocumentsRequest,
            RerankRequest,
        )
        if isinstance(request, no_output_requests):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num > max_model_len:
                # Looked up by exact type; anything else (including
                # RerankRequest) falls back to "embedding generation".
                operations: dict[type[AnyRequest], str] = dict.fromkeys(
                    (
                        ScoreDataRequest,
                        ScoreTextRequest,
                        ScoreQueriesDocumentsRequest,
                    ),
                    "score",
                )
                operation = operations.get(type(request), "embedding generation")
                raise VLLMValidationError(
                    f"This model's maximum context length is "
                    f"{max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.",
                    parameter="input_tokens",
                    value=token_num,
                )
            return tokens_prompt

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        tokenizer_only_requests = (
            TokenizeCompletionRequest,
            TokenizeChatRequest,
            DetokenizeRequest,
        )
        if isinstance(request, tokenizer_only_requests):
            return tokens_prompt

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= max_model_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages.",
                parameter="input_tokens",
                value=token_num,
            )

        if max_tokens is not None and token_num + max_tokens > max_model_len:
            total_tokens = token_num + max_tokens
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{max_model_len} tokens. However, you requested "
                f"{max_tokens} output tokens and your prompt contains "
                f"{token_num} input tokens, for a total of "
                f"{total_tokens} tokens "
                f"({token_num} + {max_tokens} = "
                f"{total_tokens} > {max_model_len}). "
                f"Please reduce the length of the input prompt or the "
                f"number of requested output tokens.",
                parameter="max_tokens",
                value=max_tokens,
            )

        return tokens_prompt
793

794
795
    def _validate_chat_template(
        self,
796
797
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
798
        trust_request_chat_template: bool,
799
    ) -> ErrorResponse | None:
800
        if not trust_request_chat_template and (
801
802
803
804
805
806
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
807
808
809
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
810
811
                "Refused request with untrusted chat template."
            )
812
813
        return None

814
815
816
817
818
819
820
821
822
823
824
825
    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs."""
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
        return default_chat_template_kwargs | request_chat_template_kwargs

826
827
828
829
830
    async def _preprocess_completion(
        self,
        request: RendererRequest,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
    ) -> list[ProcessorInputs]:
        """Flatten raw completion inputs and forward them to ``_preprocess_cmpl``.

        Both ``prompt_embeds`` and ``prompt_input`` are expanded with
        ``prompt_to_seq``; either may be ``None``, in which case it
        contributes nothing.
        """
        prompts: list[SingletonPrompt | bytes] = []
        # Embeddings are queued ahead of text/token prompts when both are given.
        for source in (prompt_embeds, prompt_input):
            if source is not None:
                prompts.extend(prompt_to_seq(source))

        return await self._preprocess_cmpl(request, prompts)

    async def _preprocess_cmpl(
        self,
        request: RendererRequest,
        prompts: Sequence[PromptType | bytes],
    ) -> list[ProcessorInputs]:
        """Parse completion prompts and render them through ``self.renderer``.

        ``bytes`` entries are passed through untouched; every other entry is
        run through ``parse_model_prompt``. The request's
        ``mm_processor_kwargs`` and ``cache_salt`` attributes, when present
        and not ``None``, are forwarded as ``prompt_extras``.
        """
        renderer = self.renderer
        model_config = self.model_config

        parsed_prompts = []
        for prompt in prompts:
            if not isinstance(prompt, bytes):
                prompt = parse_model_prompt(model_config, prompt)
            parsed_prompts.append(prompt)

        tok_params = request.build_tok_params(model_config)

        # Only forward optional request attributes that are actually set.
        prompt_extras = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                prompt_extras[key] = value

        return await renderer.render_cmpl_async(
            parsed_prompts,
            tok_params,
            prompt_extras=prompt_extras,
        )
867

868
869
    async def _preprocess_chat(
        self,
        request: RendererChatRequest,
        messages: list[ChatCompletionMessageParam],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
    ) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
        """Render one chat conversation into engine inputs.

        Args:
            request: Supplies the tokenization and chat parameters via
                ``build_tok_params`` / ``build_chat_params``.
            messages: The conversation to render.
            default_template: Passed to ``request.build_chat_params``.
            default_template_content_format: Passed to
                ``request.build_chat_params``.
            default_template_kwargs: Merged (via ``merge_kwargs``) with the
                ``tools`` and ``tokenize`` kwargs computed here.
            tool_dicts: Tool definitions exposed to the template as ``tools``.
            tool_parser: Factory building a tool parser from a tokenizer.

        Returns:
            The rendered conversation and a one-element list holding the
            engine prompt.

        Raises:
            NotImplementedError: A tool parser is set and ``tool_choice`` is
                not ``"none"``, but the request is neither a
                ``ChatCompletionRequest`` nor a ``ResponsesRequest``.
        """
        renderer = self.renderer

        template_kwargs = merge_kwargs(
            default_template_kwargs,
            dict(
                tools=tool_dicts,
                tokenize=is_mistral_tokenizer(renderer.tokenizer),
            ),
        )

        mm_config = self.model_config.multimodal_config

        tok_params = request.build_tok_params(self.model_config)
        base_chat_params = request.build_chat_params(
            default_template, default_template_content_format
        )
        chat_params = base_chat_params.with_defaults(
            template_kwargs,
            default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
            default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None),
        )

        # Only forward optional request attributes that are actually set.
        prompt_extras = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                prompt_extras[key] = value

        conversations, engine_prompts = await renderer.render_chat_async(
            [messages],
            chat_params,
            tok_params,
            prompt_extras=prompt_extras,
        )
        # Exactly one conversation was submitted, so expect exactly one result.
        (conversation,) = conversations
        (engine_prompt,) = engine_prompts

        # Tool handling happens only when a tool_parser has been set and
        # tool_choice is not "none". (If tool_choice is "none" but a
        # tool_parser is set, we want to prevent parsing a tool_call
        # hallucinated by the LLM.)
        tools_active = (
            tool_parser is not None
            and getattr(request, "tool_choice", "none") != "none"
        )
        if tools_active:
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                raise NotImplementedError(
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )

            # TODO: Update adjust_request to accept ResponsesRequest
            # NOTE(review): the returned request is not used by this method;
            # presumably adjust_request mutates `request` in place - confirm.
            tokenizer = renderer.get_tokenizer()
            tool_parser(tokenizer).adjust_request(request=request)  # type: ignore[arg-type]

        return conversation, [engine_prompt]
928

929
    def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs):
        """Delegate to ``extract_prompt_components`` using this server's model config."""
        model_config = self.model_config
        return extract_prompt_components(model_config, prompt)

932
    def _extract_prompt_text(self, prompt: ProcessorInputs):
933
934
        return self._extract_prompt_components(prompt).text

935
    def _extract_prompt_len(self, prompt: ProcessorInputs):
        """Delegate to ``extract_prompt_len`` using this server's model config."""
        model_config = self.model_config
        return extract_prompt_len(model_config, prompt)

938
939
940
941
942
    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        """Render the engine prompts for a follow-up turn of a Responses request.

        ``messages`` are converted with ``construct_input_messages`` and then
        rendered through ``_preprocess_chat`` with no default template kwargs.
        Only the engine prompts are returned; the rendered conversation is
        discarded.
        """
        next_turn_messages = construct_input_messages(request_input=messages)

        _conversation, engine_prompts = await self._preprocess_chat(
            request,
            next_turn_messages,
            default_template=chat_template,
            default_template_content_format=chat_template_content_format,
            default_template_kwargs=None,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
        )
        return engine_prompts
961

962
963
964
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        engine_prompt: ProcessorInputs,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        trace_headers: Mapping[str, str] | None = None,
    ):
        """Generate, looping for as long as the context asks for built-in tool calls.

        Each turn issues one engine request with id ``f"{request_id}_{i}"``
        (``i`` counting from 0), appends every engine output to ``context``
        and yields ``context`` after each one. When the context needs a
        built-in tool call, the tool is invoked, its output is appended, and
        a new prompt is rendered for the next turn.

        Note that ``sampling_params.max_tokens`` is overwritten in place
        before each follow-up turn for the context types handled below.
        """
        max_model_len = self.model_config.max_model_len

        # Every turn after the first is submitted with this priority value.
        followup_priority = priority - 1
        turn = 0
        while True:
            # Ensure that each sub-request has a unique request id.
            sub_request_id = f"{request_id}_{turn}"

            self._log_inputs(
                sub_request_id,
                engine_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )

            output_stream = self.engine_client.generate(
                engine_prompt,
                sampling_params,
                sub_request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )

            async for engine_output in output_stream:
                context.append_output(engine_output)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                return

            # Call the tool and update the context with the result.
            tool_output = await context.call_tool()
            context.append_tool_output(tool_output)

            # TODO: also yield the context here once tool output streaming
            # is enabled.

            # Create inputs for the next turn: render the next prompt and
            # update sampling_params accordingly.
            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                token_ids = context.render_for_completion()
                engine_prompt = token_inputs(token_ids)

                sampling_params.max_tokens = max_model_len - len(token_ids)
            elif isinstance(context, ParsableContext):
                (engine_prompt,) = await self._render_next_turn(
                    context.request,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )

                sampling_params.max_tokens = get_max_tokens(
                    max_model_len,
                    context.request.max_output_tokens,
                    self._extract_prompt_len(engine_prompt),
                    self.default_sampling_params,  # type: ignore
                    self.override_max_tokens,  # type: ignore
                )

            # OPTIMIZATION
            priority = followup_priority
            turn += 1
1040

1041
1042
1043
    def _log_inputs(
        self,
        request_id: str,
1044
        inputs: PromptType | ProcessorInputs,
1045
1046
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
1047
1048
1049
    ) -> None:
        if self.request_logger is None:
            return
1050

1051
        components = self._extract_prompt_components(inputs)
1052
1053
1054

        self.request_logger.log_inputs(
            request_id,
1055
1056
1057
            components.text,
            components.token_ids,
            components.embeds,
1058
1059
1060
            params=params,
            lora_request=lora_request,
        )
1061

1062
1063
1064
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return trace headers from ``headers`` when the engine has tracing enabled.

        When tracing is disabled, ``None`` is returned; if the request
        nevertheless carries trace headers, ``log_tracing_disabled_warning``
        is called first.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

1076
    @staticmethod
1077
    def _base_request_id(
1078
1079
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
1080
        """Pulls the request id to use from a header, if provided"""
1081
1082
1083
1084
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id
1085

1086
        return random_uuid() if default is None else default
1087

1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided"""
        if raw_request is None:
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None

        try:
            return int(rank_str)
        except ValueError:
            return None

1103
1104
1105
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Convert generated ``content`` into function calls per ``tool_choice``.

        Handling by ``request.tool_choice``:

        * ``ToolChoiceFunction`` or ``ChatCompletionNamedToolChoiceParam``
          (forced call): the whole content becomes the arguments of the
          named function.
        * ``"required"``: the content is validated as a JSON list of
          ``FunctionDefinition``; a validation failure yields no calls.
        * ``"auto"`` or ``None``, with a parser class and
          ``enable_auto_tools`` set: the tool parser extracts the calls.
        * anything else: an empty list and the untouched content.

        Returns:
            ``(function_calls, remaining_content)``. ``function_calls`` is
            ``None`` only when the automatic parser reported no tool calls.

        Raises:
            AssertionError: A forced call was requested with ``content=None``.
            ValueError: Automatic parsing was selected but ``tokenizer`` is
                ``None``.
            RuntimeError: Re-raised, after logging, when constructing the
                tool parser fails.
        """
        tool_choice = request.tool_choice

        # Forced function call (Responses-style tool choice object).
        if tool_choice and isinstance(tool_choice, ToolChoiceFunction):
            assert content is not None
            forced = FunctionCall(name=tool_choice.name, arguments=content)
            # Content is cleared since it was consumed as the call arguments.
            return [forced], None

        # Forced function call (Chat Completions named tool choice).
        if tool_choice and isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
            assert content is not None
            forced = FunctionCall(name=tool_choice.function.name, arguments=content)
            return [forced], None

        if tool_choice == "required":
            try:
                definitions = TypeAdapter(list[FunctionDefinition]).validate_json(
                    content or ""
                )
            except ValidationError:
                # Content that does not validate produces no tool calls.
                definitions = []
            required_calls = [
                FunctionCall(
                    name=definition.name,
                    arguments=json.dumps(definition.parameters, ensure_ascii=False),
                )
                for definition in definitions
            ]
            return required_calls, None

        auto_parsing = (
            tool_parser_cls
            and enable_auto_tools
            and (tool_choice == "auto" or tool_choice is None)
        )
        if not auto_parsing:
            return [], content

        if tokenizer is None:
            raise ValueError(
                "Tokenizer not available when `skip_tokenizer_init=True`"
            )

        # Automatic Tool Call Parsing
        try:
            tool_parser = tool_parser_cls(tokenizer)
        except RuntimeError as e:
            logger.exception("Error in tool parser creation.")
            raise e

        tool_call_info = tool_parser.extract_tool_calls(
            content if content is not None else "",
            request=request,  # type: ignore
        )
        if tool_call_info is None or not tool_call_info.tools_called:
            # No tool calls.
            return None, content

        # extract_tool_calls() returns a list of tool calls.
        function_calls = [
            FunctionCall(
                id=tool_call.id,
                name=tool_call.function.name,
                arguments=tool_call.function.arguments,
            )
            for tool_call in tool_call_info.tool_calls
        ]
        content = tool_call_info.content
        # Whitespace-only leftover content is reported as no content.
        if content and content.strip() == "":
            content = None
        return function_calls, content

1182
    @staticmethod
1183
1184
1185
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
1186
        tokenizer: TokenizerLike | None,
1187
1188
        return_as_token_id: bool = False,
    ) -> str:
1189
1190
1191
        if return_as_token_id:
            return f"token_id:{token_id}"

1192
1193
        if logprob.decoded_token is not None:
            return logprob.decoded_token
1194
1195
1196
1197
1198
1199

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

1200
        return tokenizer.decode([token_id])
1201

1202
    def _is_model_supported(self, model_name: str | None) -> bool:
1203
1204
        if not model_name:
            return True
1205
        return self.models.is_base_model(model_name)
1206

1207
1208

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    ``None`` input and ``None`` per-position entries are left alone. The
    same ``prompt_logprobs`` object that was passed in is returned.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    negative_infinity = float("-inf")
    for position_logprobs in prompt_logprobs:
        if position_logprobs is not None:
            for entry in position_logprobs.values():
                if entry.logprob == negative_infinity:
                    entry.logprob = -9999.0
    return prompt_logprobs