# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
4
import json
5
import sys
6
import time
7
import traceback
8
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
9
from dataclasses import dataclass, field
10
from http import HTTPStatus
11
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
12

13
import numpy as np
14
from fastapi import Request
15
16
17
from openai.types.responses import (
    ToolChoiceFunction,
)
18
19
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
20

21
import vllm.envs as envs
22
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
23
from vllm.config import ModelConfig
24
from vllm.engine.protocol import EngineClient
25
26
27
28
29
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
30
from vllm.entrypoints.logger import RequestLogger
31
from vllm.entrypoints.openai.chat_completion.protocol import (
32
    ChatCompletionNamedToolChoiceParam,
33
34
    ChatCompletionRequest,
    ChatCompletionResponse,
35
)
36
from vllm.entrypoints.openai.completion.protocol import (
37
38
    CompletionRequest,
    CompletionResponse,
39
40
)
from vllm.entrypoints.openai.engine.protocol import (
41
42
    ErrorInfo,
    ErrorResponse,
43
    FunctionCall,
44
    FunctionDefinition,
45
)
46
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
47
48
49
50
51
52
from vllm.entrypoints.openai.responses.context import (
    ConversationContext,
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
53
54
55
56
from vllm.entrypoints.openai.responses.protocol import (
    ResponseInputOutputItem,
    ResponsesRequest,
)
57
58
59
from vllm.entrypoints.openai.responses.utils import (
    construct_input_messages,
)
60
from vllm.entrypoints.openai.speech_to_text.protocol import (
61
62
63
64
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
65
from vllm.entrypoints.pooling.embed.protocol import (
66
    EmbeddingBytesResponse,
67
68
69
70
71
72
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
73
74
    PoolingChatRequest,
    PoolingCompletionRequest,
75
76
77
78
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
79
80
    ScoreDataRequest,
    ScoreQueriesDocumentsRequest,
81
82
    ScoreRequest,
    ScoreResponse,
83
    ScoreTextRequest,
84
)
85
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
86
87
88
89
90
91
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
92
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
93
from vllm.exceptions import VLLMValidationError
94
95
96
97
98
99
100
from vllm.inputs.data import (
    ProcessorInputs,
    PromptType,
    SingletonPrompt,
    TokensPrompt,
    token_inputs,
)
101
from vllm.logger import init_logger
102
from vllm.logprobs import Logprob, PromptLogprobs
103
from vllm.lora.request import LoRARequest
104
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
105
from vllm.pooling_params import PoolingParams
106
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
107
108
109
110
111
112
from vllm.renderers.inputs.preprocess import (
    extract_prompt_components,
    extract_prompt_len,
    parse_model_prompt,
    prompt_to_seq,
)
113
from vllm.sampling_params import BeamSearchParams, SamplingParams
114
from vllm.tokenizers import TokenizerLike
115
from vllm.tool_parsers import ToolParser
116
117
118
119
120
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
121
from vllm.utils import random_uuid
122
from vllm.utils.async_utils import (
123
    collect_from_async_generator,
124
125
    merge_async_iterators,
)
126
from vllm.utils.mistral import is_mistral_tokenizer
127

128
129
130
131
132
133
134
135
136

class GenerationError(Exception):
    """Signals that generation finished with an internal server error (HTTP 500)."""

    def __init__(self, message: str = "Internal server error"):
        # Every instance carries the same HTTP status for error conversion.
        self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
        super().__init__(message)


137
138
logger = init_logger(__name__)

139
140
141
142
143
144
145
146
147
148
149
150
151
152
153

class RendererRequest(Protocol):
    """Structural type for requests that can build tokenization parameters."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Return the `TokenizeParams` for this request given `model_config`."""
        raise NotImplementedError


class RendererChatRequest(RendererRequest, Protocol):
    """Structural type for requests that can additionally build chat parameters."""

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        """Return the `ChatParams` for this request given the default template
        and default template content format."""
        raise NotImplementedError


154
155
# Union of the completion-style request types.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | TokenizeCompletionRequest
    | DetokenizeRequest
    | EmbeddingCompletionRequest
    | RerankRequest
    | ScoreRequest
    | PoolingCompletionRequest
)

# Union of the chat-style request types.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | TokenizeChatRequest
    | EmbeddingChatRequest
    | PoolingChatRequest
)

SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Every request type accepted by the serving layer in this module.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Every non-error response type produced by the serving layer in this module.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | EmbeddingResponse
    | EmbeddingBytesResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ScoreResponse
    | GenerateResponse
)

# Type variable used to parameterize `ServeContext` by its request type.
RequestT = TypeVar("RequestT", bound=AnyRequest)


197
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
    """Per-request state shared by the `OpenAIServing` pipeline stages.

    All fields are keyword-only (`kw_only=True`), which is what allows the
    required `model_name` / `request_id` fields to follow defaulted ones.
    """

    # The API request being served.
    request: RequestT
    # The underlying FastAPI request, when one is available.
    raw_request: Request | None = None
    model_name: str
    request_id: str
    # Creation timestamp in whole seconds since the epoch.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None
    # Read by `_prepare_generators` / `_collect_batch`; both report an error
    # while this is still None.
    engine_prompts: list[ProcessorInputs] | None = None

    # Set by `_prepare_generators`; consumed by `_collect_batch` as
    # (prompt index, output) pairs.
    result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
        None
    )
    # Filled by `_collect_batch`, ordered by prompt index.
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # Unannotated, so this is a plain class attribute rather than a
    # dataclass field.
    model_config = ConfigDict(arbitrary_types_allowed=True)
213
214


215
class OpenAIServing:
216
    request_id_prefix: ClassVar[str] = """
217
218
    A short string prepended to every request’s ID (e.g. "embd")
    so you can easily tell “this ID came from Embedding.”
219
    """
220

221
222
    def __init__(
        self,
223
        engine_client: EngineClient,
224
        models: OpenAIServingModels,
225
        *,
226
        request_logger: RequestLogger | None,
227
        return_tokens_as_token_ids: bool = False,
228
        log_error_stack: bool = False,
229
    ):
230
231
        super().__init__()

232
        self.engine_client = engine_client
233

234
        self.models = models
235

236
        self.request_logger = request_logger
237
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
238

239
        self.log_error_stack = log_error_stack
240

241
242
243
244
        self.model_config = engine_client.model_config
        self.renderer = engine_client.renderer
        self.io_processor = engine_client.io_processor
        self.input_processor = engine_client.input_processor
245
246
247

    async def beam_search(
        self,
        prompt: ProcessorInputs,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search for one prompt and yield a single final output.

        For up to `params.max_tokens` steps, one single-token generation is
        issued per live beam (requesting `2 * beam_width` logprobs). The
        candidate continuations of all beams are pooled; unless `ignore_eos`
        is set, candidates ending in EOS are moved to the completed set, and
        the `beam_width` best remaining candidates become the next beams.

        Args:
            prompt: Processed prompt; must expose `prompt_token_ids`.
            request_id: Base request ID; per-step and per-beam suffixes are
                appended for the underlying generate calls.
            params: Beam search parameters.
            lora_request: Optional LoRA adapter attached to the initial beam.
            trace_headers: Optional tracing headers forwarded to the engine.

        Yields:
            Exactly one finished `RequestOutput`: the best beams, or an
            output with `finish_reason="error"` if any per-beam generation
            reported an error.

        Raises:
            NotImplementedError: For "embeds" and "enc_dec" prompt types.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        if prompt["type"] == "embeds":
            raise NotImplementedError("Embedding prompt not supported for beam search")
        if prompt["type"] == "enc_dec":
            raise NotImplementedError(
                "Encoder-decoder prompt not supported for beam search"
            )

        prompt_text = prompt.get("prompt")
        prompt_token_ids = prompt["prompt_token_ids"]
        tokenized_length = len(prompt_token_ids)

        logprobs_num = 2 * beam_width
        sampling_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            # Launch one single-token generation per live beam.
            for i, beam in enumerate(all_beams):
                prompt_item = beam.get_prompt()
                lora_request_item = beam.lora_request
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            prompt_item,
                            sampling_params,
                            request_id_item,
                            lora_request=lora_request_item,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Candidate next tokens pooled across all beams.
            candidate_token_ids = []
            # Cumulative logprob of each candidate (beam logprob + token logprob).
            candidate_logprobs = []
            # Index of the beam each candidate extends. Tracked explicitly
            # because a result may contribute a number of logprob entries
            # other than `logprobs_num` (or none at all), which would make
            # `idx // logprobs_num` point at the wrong beam.
            candidate_owner = []

            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    candidate_token_ids.extend(logprobs.keys())
                    candidate_logprobs.extend(
                        current_beam.cum_logprob + obj.logprob
                        for obj in logprobs.values()
                    )
                    candidate_owner.extend([i] * len(logprobs))

            all_beams_token_id = np.array(candidate_token_ids)
            all_beams_logprob = np.array(candidate_logprobs)

            # Handle the token for the end of sentence (EOS)
            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
                for idx in eos_idx:
                    owner = candidate_owner[idx]
                    current_beam = all_beams[owner]
                    result = output[owner]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            orig_prompt=prompt,
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # After processing, set the log probability of the eos condition
                # to negative infinity.
                all_beams_logprob[eos_idx] = -np.inf

            # Processing non-EOS tokens
            num_candidates = all_beams_logprob.size
            if num_candidates == 0:
                # No beam produced logprobs; keep the current beams and stop.
                break
            if num_candidates > beam_width:
                # Get indices of the top beam_width probabilities
                topn_idx = np.argpartition(
                    np.negative(all_beams_logprob), beam_width
                )[:beam_width]
            else:
                # argpartition requires kth < size; keep every candidate.
                topn_idx = np.arange(num_candidates)

            for idx in topn_idx:
                owner = candidate_owner[idx]
                current_beam = all_beams[owner]
                result = output[owner]
                token_id = int(all_beams_token_id[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        orig_prompt=prompt,
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                    )
                )

            all_beams = new_beams

        # Beams still alive compete with the completed ones.
        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
443

444
445
446
    async def _preprocess(
        self,
        ctx: ServeContext,
447
    ) -> ErrorResponse | None:
448
449
        """
        Default preprocessing hook. Subclasses may override
450
        to prepare `ctx` (embedding, etc.).
451
452
453
454
455
456
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """
        Default response builder. Subclass may override this method
        to return the appropriate response object.

        The default implementation returns an "unimplemented endpoint"
        error response.
        """
        return self.create_error_response("unimplemented endpoint")

    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Run the pipeline for `ctx` and return the first response it yields.

        Returns an error response if the pipeline yields nothing.
        """
        async for response in self._pipeline(ctx):
            # Only the first yielded item is used.
            return response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
        """Execute the request processing pipeline yielding responses.

        Stages run in order: model check, request validation, preprocessing,
        generator preparation, batch collection, response building. The
        first stage that reports an `ErrorResponse` has it yielded and the
        pipeline stops, so later stages never run after a failure even if
        the consumer keeps iterating.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

497
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
498
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
499

500
501
        if (
            truncate_prompt_tokens is not None
502
            and truncate_prompt_tokens > self.model_config.max_model_len
503
        ):
504
505
506
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
507
508
                " Please, select a smaller truncation size."
            )
509
510
        return None

511
512
513
    def _create_pooling_params(
        self,
        ctx: ServeContext,
514
    ) -> PoolingParams | ErrorResponse:
515
516
        if not hasattr(ctx.request, "to_pooling_params"):
            return self.create_error_response(
517
518
                "Request type does not support pooling parameters"
            )
519
520
521

        return ctx.request.to_pooling_params()

522
523
524
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Schedule the request and get the result generator.

        Issues one `engine_client.encode` call per entry in
        `ctx.engine_prompts` and stores the merged generator on
        `ctx.result_generator`.

        Returns:
            `None` on success, or an `ErrorResponse` when pooling parameters
            cannot be built, engine prompts are missing, or any exception is
            raised while scheduling.
        """
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

        try:
            trace_headers = (
                None
                if ctx.raw_request is None
                else await self._get_trace_headers(ctx.raw_request.headers)
            )

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                # Each prompt gets its own sub-request ID.
                request_id_item = f"{ctx.request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=getattr(ctx.request, "priority", 0),
                )

                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # Convert any scheduling failure into an API error response.
            return self.create_error_response(e)
570
571
572
573

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Collect batch results from the result generator.

        Drains `ctx.result_generator`, placing each result at its prompt
        index, and stores the ordered list on `ctx.final_res_batch`.

        Returns:
            `None` on success, or an `ErrorResponse` when engine prompts or
            the result generator are missing, when any prompt has no result,
            or when an exception is raised while collecting.
        """
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            num_prompts = len(ctx.engine_prompts)
            final_res_batch: list[PoolingRequestOutput | None]
            final_res_batch = [None] * num_prompts

            if ctx.result_generator is None:
                return self.create_error_response("Result generator not available")

            async for i, res in ctx.result_generator:
                final_res_batch[i] = res

            # Identity check: avoids invoking `__eq__` on the output objects.
            if any(res is None for res in final_res_batch):
                return self.create_error_response(
                    "Failed to generate results for all prompts"
                )

            ctx.final_res_batch = [res for res in final_res_batch if res is not None]

            return None

        except Exception as e:
            # Convert any collection failure into an API error response.
            return self.create_error_response(e)
601

602
    def create_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an `ErrorResponse` from a message or an exception.

        When `message` is an exception, `err_type`, `status_code` and `param`
        are overridden according to the exception's class:

        - `VLLMValidationError`: 400, with `param` taken from the exception.
        - `ValueError`, `TypeError`, `RuntimeError`, `OverflowError`: 400.
          (`NotImplementedError` subclasses `RuntimeError`, so it also lands
          here.)
        - a class named exactly `TemplateError`: 400.
        - anything else: 500.

        Args:
            message: Error text, or the exception to report.
            err_type: Error type used when `message` is a string.
            status_code: HTTP status used when `message` is a string.
            param: Offending parameter name used when `message` is a string.

        Returns:
            The populated `ErrorResponse`, with its message passed through
            `sanitize_message`.
        """
        exc: Exception | None = None

        if isinstance(message, Exception):
            exc = message

            if isinstance(exc, VLLMValidationError):
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = exc.parameter
            elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
                # Common validation errors from user input
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = None
            elif isinstance(exc, NotImplementedError):
                # NOTE(review): unreachable as ordered, because
                # NotImplementedError is a RuntimeError and is caught by the
                # branch above — confirm whether 501 is the intended status.
                err_type = "NotImplementedError"
                status_code = HTTPStatus.NOT_IMPLEMENTED
                param = None
            elif exc.__class__.__name__ == "TemplateError":
                # jinja2.TemplateError (avoid importing jinja2)
                # NOTE(review): matches the exact class name only, so
                # subclasses fall through to the 500 branch — confirm intended.
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = None
            else:
                err_type = "InternalServerError"
                status_code = HTTPStatus.INTERNAL_SERVER_ERROR
                param = None

            message = str(exc)

        if self.log_error_stack:
            # Inside an `except` block print that exception's traceback;
            # otherwise print the stack leading to this call.
            exc_type, _, _ = sys.exc_info()
            if exc_type is not None:
                traceback.print_exc()
            else:
                traceback.print_stack()

        return ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(message),
                type=err_type,
                code=status_code.value,
                param=param,
            )
        )
656

657
    def create_streaming_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> str:
        """Build an error response and return it serialized as a JSON string.

        Arguments are forwarded unchanged to `create_error_response`.
        """
        error_response = self.create_error_response(
            message=message,
            err_type=err_type,
            status_code=status_code,
            param=param,
        )
        return json.dumps(error_response.model_dump())

674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Raise GenerationError if finish_reason indicates an error."""
        if finish_reason == "error":
            logger.error(
                "Request %s failed with an internal error during generation",
                request_id,
            )
            raise GenerationError("Internal server error")

    def _convert_generation_error_to_response(
        self, e: GenerationError
    ) -> ErrorResponse:
        """Turn a `GenerationError` into an "InternalServerError" response
        carrying the exception's message and status code."""
        detail = str(e)
        return self.create_error_response(
            detail,
            err_type="InternalServerError",
            status_code=e.status_code,
        )

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Turn a `GenerationError` into a serialized "InternalServerError"
        streaming error carrying the exception's message and status code."""
        detail = str(e)
        return self.create_streaming_error_response(
            detail,
            err_type="InternalServerError",
            status_code=e.status_code,
        )

703
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Check that `request.model` names a usable model or LoRA adapter.

        Returns:
            `None` when the model is supported, is a registered LoRA adapter,
            or (with `VLLM_ALLOW_RUNTIME_LORA_UPDATING` enabled) resolves to
            a LoRA adapter at runtime. Otherwise an `ErrorResponse`: the
            resolver's own 400 error if it produced one, else a 404
            "does not exist" error.
        """
        error_response = None

        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
            if isinstance(load_result, LoRARequest):
                return None
            # Only a bad-request error from the resolver is surfaced;
            # any other failure falls back to the generic 404 below.
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
                error_response = load_result

        return error_response or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )
732

733
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

756
    def _maybe_get_adapters(
757
758
759
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
760
    ) -> LoRARequest | None:
761
        if request.model in self.models.lora_requests:
762
            return self.models.lora_requests[request.model]
763
764
765
766
767
768

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
769
                return default_mm_lora
770
771

        if self._is_model_supported(request.model):
772
            return None
773

774
        # if _check_model has been called earlier, this will be unreachable
775
        raise ValueError(f"The model `{request.model}` does not exist.")
776

777
778
779
780
781
782
783
784
785
786
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

787
788
789
790
791
        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
792
793
794
795
796
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
797
798
799
800
801
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

802
803
    def _validate_input(
        self,
        request: object,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Check the tokenized prompt against the model's context length.

        Three families of requests are handled differently:

        * Embedding, score and rerank requests may use the whole context,
          because they do not generate tokens.
        * Tokenize and detokenize requests are not length-checked at all.
        * Every other request must be strictly shorter than the context
          length, and prompt length plus the requested output budget (when
          one is set) must fit in the context.

        Returns a `TokensPrompt` built from `input_text` and `input_ids`.
        Raises `VLLMValidationError` when a limit is exceeded.
        """
        prompt_len = len(input_ids)
        context_len = self.model_config.max_model_len

        score_request_types = (
            ScoreDataRequest,
            ScoreTextRequest,
            ScoreQueriesDocumentsRequest,
        )
        # Note: none of these request types has a max_tokens field.
        pooling_request_types = (
            EmbeddingChatRequest,
            EmbeddingCompletionRequest,
            *score_request_types,
            RerankRequest,
        )

        if isinstance(request, pooling_request_types):
            # Nothing is generated for these requests, so the input may be
            # as long as the entire model context length.
            if prompt_len > context_len:
                # Exact type match (not isinstance) picks the wording.
                if type(request) in score_request_types:
                    operation = "score"
                else:
                    operation = "embedding generation"
                raise VLLMValidationError(
                    "This model's maximum context length is "
                    f"{context_len} tokens. However, you requested "
                    f"{prompt_len} tokens in the input for {operation}. "
                    "Please reduce the length of the input.",
                    parameter="input_tokens",
                    value=prompt_len,
                )
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Note: tokenize/detokenize requests have no max_tokens and do not
        # require model context length validation.
        tokenizer_request_types = (
            TokenizeCompletionRequest,
            TokenizeChatRequest,
            DetokenizeRequest,
        )
        if isinstance(request, tokenizer_request_types):
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        if isinstance(request, ChatCompletionRequest):
            # The chat completion endpoint supports max_completion_tokens.
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            output_budget = request.max_completion_tokens or request.max_tokens
        else:
            output_budget = getattr(request, "max_tokens", None)

        # Note: for completion-like requests the input can be at most
        # model context length - 1.
        if prompt_len >= context_len:
            raise VLLMValidationError(
                "This model's maximum context length is "
                f"{context_len} tokens. However, your request has "
                f"{prompt_len} input tokens. Please reduce the length of "
                "the input messages.",
                parameter="input_tokens",
                value=prompt_len,
            )

        if output_budget is not None and prompt_len + output_budget > context_len:
            raise VLLMValidationError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{output_budget}. This model's maximum context length is "
                f"{context_len} tokens and your request has "
                f"{prompt_len} input tokens ({output_budget} > {context_len}"
                f" - {prompt_len}).",
                parameter="max_tokens",
                value=output_budget,
            )

        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
882

883
884
    def _validate_chat_template(
        self,
885
886
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
887
        trust_request_chat_template: bool,
888
    ) -> ErrorResponse | None:
889
        if not trust_request_chat_template and (
890
891
892
893
894
895
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
896
897
898
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
899
900
                "Refused request with untrusted chat template."
            )
901
902
        return None

903
904
905
906
907
908
909
910
911
912
913
914
    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs."""
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
        return default_chat_template_kwargs | request_chat_template_kwargs

915
916
917
918
919
    async def _preprocess_completion(
        self,
        request: RendererRequest,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
    ) -> list[ProcessorInputs]:
        """Collect completion prompts into one sequence and render them.

        Both inputs are optional and are expanded with `prompt_to_seq`;
        embedding prompts are placed ahead of text/token prompts.
        """
        prompts: list[SingletonPrompt | bytes] = []
        # Embeds take higher priority, so they are added first.
        if prompt_embeds is not None:
            prompts += prompt_to_seq(prompt_embeds)
        if prompt_input is not None:
            prompts += prompt_to_seq(prompt_input)

        rendered = await self._preprocess_cmpl(request, prompts)
        return rendered

    async def _preprocess_cmpl(
        self,
        request: RendererRequest,
        prompts: Sequence[PromptType | bytes],
    ) -> list[ProcessorInputs]:
        """Parse completion prompts and render them into engine inputs.

        `bytes` prompts are handed to the renderer untouched; every other
        prompt is first run through `parse_model_prompt`. The request's
        `mm_processor_kwargs` and `cache_salt` attributes, when present and
        not None, are forwarded as prompt extras.
        """
        model_config = self.model_config

        parsed = [
            item if isinstance(item, bytes) else parse_model_prompt(model_config, item)
            for item in prompts
        ]
        tok_params = request.build_tok_params(model_config)

        # Only forward the optional extras the request actually carries.
        extras: dict[str, Any] = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                extras[key] = value

        return await self.renderer.render_cmpl_async(
            parsed,
            tok_params,
            prompt_extras=extras,
        )
956

957
958
    async def _preprocess_chat(
        self,
        request: RendererChatRequest,
        messages: list[ChatCompletionMessageParam],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
    ) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
        """Render a single chat conversation into engine inputs.

        Returns the rendered conversation together with a one-element list
        holding the engine prompt. When a `tool_parser` is given and the
        request's `tool_choice` is not `"none"`, the parser's
        `adjust_request` is invoked on the request after rendering.

        Raises `NotImplementedError` if tool handling is requested for a
        request that is neither a `ChatCompletionRequest` nor a
        `ResponsesRequest`.
        """
        renderer = self.renderer

        template_kwargs = merge_kwargs(
            default_template_kwargs,
            {
                "tools": tool_dicts,
                "tokenize": is_mistral_tokenizer(renderer.tokenizer),
            },
        )

        tok_params = request.build_tok_params(self.model_config)
        base_chat_params = request.build_chat_params(
            default_template, default_template_content_format
        )
        chat_params = base_chat_params.with_defaults(template_kwargs)

        # Only forward the optional extras the request actually carries.
        extras: dict[str, Any] = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                extras[key] = value

        # A single conversation goes in, so exactly one conversation and
        # one engine prompt are expected back.
        conversations, engine_prompts = await renderer.render_chat_async(
            [messages],
            chat_params,
            tok_params,
            prompt_extras=extras,
        )
        (conversation,) = conversations
        (engine_prompt,) = engine_prompts

        if tool_parser is None:
            return conversation, [engine_prompt]

        # Tool parsing is done only if a tool_parser has been set and
        # tool_choice is not "none". If tool_choice is "none" while a
        # tool_parser is set, we want to prevent parsing a tool_call
        # hallucinated by the LLM.
        if getattr(request, "tool_choice", "none") != "none":
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                raise NotImplementedError(
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )

            # TODO: Update adjust_request to accept ResponsesRequest
            tokenizer = renderer.get_tokenizer()
            # The returned request is not used after this point.
            # NOTE(review): any effect presumably comes from adjust_request
            # mutating `request` in place -- confirm against the parsers.
            tool_parser(tokenizer).adjust_request(request=request)  # type: ignore[arg-type]

        return conversation, [engine_prompt]
1011

1012
    def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs):
        """Break `prompt` into its components using this server's model config."""
        components = extract_prompt_components(self.model_config, prompt)
        return components

1015
    def _extract_prompt_text(self, prompt: ProcessorInputs):
1016
1017
        return self._extract_prompt_components(prompt).text

1018
    def _extract_prompt_len(self, prompt: ProcessorInputs):
        """Return the length of `prompt` as computed by `extract_prompt_len`."""
        prompt_len = extract_prompt_len(self.model_config, prompt)
        return prompt_len

1021
1022
1023
1024
1025
    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        """Render Responses API items into engine prompts for the next turn.

        The items are converted with `construct_input_messages` and passed
        through `_preprocess_chat` without default template kwargs. Only the
        engine prompts are returned.
        """
        chat_messages = construct_input_messages(request_input=messages)

        # The rendered conversation is not needed here, only the prompts.
        _conversation, engine_prompts = await self._preprocess_chat(
            request,
            chat_messages,
            default_template=chat_template,
            default_template_content_format=chat_template_content_format,
            default_template_kwargs=None,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
        )
        return engine_prompts
1044

1045
1046
1047
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        engine_prompt: ProcessorInputs,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        trace_headers: Mapping[str, str] | None = None,
    ):
        """Generate output, running built-in tool calls between turns.

        This is an async generator: it yields `context` after each engine
        output has been appended to it. When a finished turn requests a
        built-in tool, the tool is called, its output is appended to the
        context, the next prompt is rendered from the context, and
        generation restarts under a new sub-request id
        (`"{request_id}_{n}"`). The generator ends once the context no
        longer requests a built-in tool call.

        `sampling_params.max_tokens` is overwritten in place for every
        follow-up turn.
        """
        context_len = self.model_config.max_model_len
        base_priority = priority
        turn_index = 0

        while True:
            # Ensure that each sub-request has a unique request id.
            turn_request_id = f"{request_id}_{turn_index}"

            self._log_inputs(
                turn_request_id,
                engine_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )

            output_stream = self.engine_client.generate(
                engine_prompt,
                sampling_params,
                turn_request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )

            async for output in output_stream:
                context.append_output(output)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                return

            # Call the tool and update the context with the result.
            tool_result = await context.call_tool()
            context.append_tool_output(tool_result)

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn: render the next prompt and
            # update sampling_params to match its length.
            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                next_token_ids = context.render_for_completion()
                engine_prompt = token_inputs(next_token_ids)

                # Whatever is left of the context window after the prompt.
                sampling_params.max_tokens = context_len - len(next_token_ids)
            elif isinstance(context, ParsableContext):
                (engine_prompt,) = await self._render_next_turn(
                    context.request,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )

                sampling_params.max_tokens = get_max_tokens(
                    context_len,
                    context.request.max_output_tokens,
                    self._extract_prompt_len(engine_prompt),
                    self.default_sampling_params,  # type: ignore
                    self.override_max_tokens,  # type: ignore
                )

            # OPTIMIZATION
            # Follow-up turns run at the caller's priority minus one.
            # NOTE(review): presumably this lets continuations be scheduled
            # ahead of new requests -- confirm the scheduler's ordering.
            priority = base_priority - 1
            turn_index += 1
1123

1124
1125
1126
    def _log_inputs(
        self,
        request_id: str,
1127
        inputs: PromptType | ProcessorInputs,
1128
1129
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
1130
1131
1132
    ) -> None:
        if self.request_logger is None:
            return
1133

1134
        components = self._extract_prompt_components(inputs)
1135
1136
1137

        self.request_logger.log_inputs(
            request_id,
1138
1139
1140
            components.text,
            components.token_ids,
            components.embeds,
1141
1142
1143
            params=params,
            lora_request=lora_request,
        )
1144

1145
1146
1147
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return the trace headers of a request when tracing is enabled.

        When tracing is disabled, returns `None`; if the request
        nevertheless carries trace headers, a tracing-disabled warning is
        logged.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

1159
    @staticmethod
1160
    def _base_request_id(
1161
1162
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
1163
        """Pulls the request id to use from a header, if provided"""
1164
1165
1166
1167
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id
1168

1169
        return random_uuid() if default is None else default
1170

1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided"""
        if raw_request is None:
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None

        try:
            return int(rank_str)
        except ValueError:
            return None

1186
1187
1188
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Turn generated `content` into function calls per `tool_choice`.

        Returns `(function_calls, remaining_content)`:

        * Forced tool choice (`ToolChoiceFunction` or
          `ChatCompletionNamedToolChoiceParam`): the whole content becomes
          the arguments of the named function; remaining content is None.
        * `"required"`: content is validated as a JSON list of function
          definitions, each becoming one call; remaining content is None.
        * `"auto"` or unset, with a parser class and `enable_auto_tools`:
          the tool parser extracts the calls. If it finds none, the result
          is `(None, content)`.
        * Anything else: an empty list and the untouched content.

        Raises `ValueError` if automatic parsing is needed but no tokenizer
        is available.
        """
        tool_choice = request.tool_choice
        calls: list[FunctionCall] = []

        if tool_choice and isinstance(tool_choice, ToolChoiceFunction):
            # Forced function call: the content is the argument payload.
            assert content is not None
            calls.append(FunctionCall(name=tool_choice.name, arguments=content))
            # No content is left since the tool is called.
            return calls, None

        if tool_choice and isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
            # Forced function call: the content is the argument payload.
            assert content is not None
            calls.append(
                FunctionCall(name=tool_choice.function.name, arguments=content)
            )
            # No content is left since the tool is called.
            return calls, None

        if tool_choice == "required":
            assert content is not None
            definitions = TypeAdapter(list[FunctionDefinition]).validate_json(content)
            calls.extend(
                FunctionCall(
                    name=definition.name,
                    arguments=json.dumps(definition.parameters, ensure_ascii=False),
                )
                for definition in definitions
            )
            # No content is left since the tools are called.
            return calls, None

        if (
            tool_parser_cls
            and enable_auto_tools
            and (tool_choice is None or tool_choice == "auto")
        ):
            if tokenizer is None:
                raise ValueError(
                    "Tokenizer not available when `skip_tokenizer_init=True`"
                )

            # Automatic tool call parsing.
            try:
                tool_parser = tool_parser_cls(tokenizer)
            except RuntimeError as e:
                logger.exception("Error in tool parser creation.")
                raise e

            extraction = tool_parser.extract_tool_calls(
                content if content is not None else "",
                request=request,  # type: ignore
            )
            if extraction is None or not extraction.tools_called:
                # No tool calls.
                return None, content

            # extract_tool_calls() returns a list of tool calls.
            calls.extend(
                FunctionCall(
                    id=parsed_call.id,
                    name=parsed_call.function.name,
                    arguments=parsed_call.function.arguments,
                )
                for parsed_call in extraction.tool_calls
            )
            content = extraction.content
            # Whitespace-only leftover content is reported as None.
            if content and not content.strip():
                content = None

        return calls, content

1263
    @staticmethod
1264
1265
1266
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
1267
        tokenizer: TokenizerLike | None,
1268
1269
        return_as_token_id: bool = False,
    ) -> str:
1270
1271
1272
        if return_as_token_id:
            return f"token_id:{token_id}"

1273
1274
        if logprob.decoded_token is not None:
            return logprob.decoded_token
1275
1276
1277
1278
1279
1280

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

1281
        return tokenizer.decode([token_id])
1282

1283
    def _is_model_supported(self, model_name: str | None) -> bool:
1284
1285
        if not model_name:
            return True
1286
        return self.models.is_base_model(model_name)
1287

1288
1289

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace every `-inf` prompt logprob with `-9999.0`, in place.

    `None` entries in the sequence are skipped. The same object that was
    passed in is returned (including `None`).
    """
    if prompt_logprobs is None:
        return None

    negative_infinity = float("-inf")
    for position_logprobs in prompt_logprobs:
        if position_logprobs is not None:
            for entry in position_logprobs.values():
                if entry.logprob == negative_infinity:
                    entry.logprob = -9999.0
    return prompt_logprobs