serving.py 31.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import contextlib
5
import json
6
import time
7
from collections.abc import AsyncGenerator, Mapping
8
from dataclasses import dataclass, field
9
from http import HTTPStatus
10
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
11

12
import numpy as np
13
from fastapi import Request
14
from openai.types.responses import ToolChoiceFunction
15
from pydantic import ConfigDict, TypeAdapter, ValidationError
16
from starlette.datastructures import Headers
17

18
import vllm.envs as envs
19
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
20
from vllm.config import ModelConfig
21
from vllm.engine.protocol import EngineClient
22
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
23
from vllm.entrypoints.logger import RequestLogger
24
from vllm.entrypoints.openai.chat_completion.protocol import (
25
    BatchChatCompletionRequest,
26
    ChatCompletionNamedToolChoiceParam,
27
28
    ChatCompletionRequest,
    ChatCompletionResponse,
29
)
30
from vllm.entrypoints.openai.completion.protocol import (
31
32
    CompletionRequest,
    CompletionResponse,
33
34
)
from vllm.entrypoints.openai.engine.protocol import (
35
    ErrorResponse,
36
    FunctionCall,
37
    FunctionDefinition,
38
    GenerationError,
39
)
40
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
41
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
42
from vllm.entrypoints.openai.speech_to_text.protocol import (
43
44
45
46
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
47
48
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
49
50
    PoolingChatRequest,
    PoolingCompletionRequest,
51
52
    PoolingResponse,
)
53
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
54
55
56
57
58
59
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
60
from vllm.entrypoints.utils import create_error_response
61
from vllm.inputs import EngineInput, PromptType
62
from vllm.logger import init_logger
63
from vllm.logprobs import Logprob, PromptLogprobs
64
from vllm.lora.request import LoRARequest
65
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
66
from vllm.pooling_params import PoolingParams
67
from vllm.renderers import ChatParams, TokenizeParams
68
69
70
71
from vllm.renderers.inputs.preprocess import (
    extract_prompt_components,
    extract_prompt_len,
)
72
from vllm.sampling_params import BeamSearchParams, SamplingParams
73
from vllm.tokenizers import TokenizerLike
74
from vllm.tool_parsers import ToolParser
75
76
77
78
79
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
80
from vllm.utils import random_uuid
81
from vllm.utils.async_utils import (
82
    collect_from_async_generator,
83
84
    merge_async_iterators,
)
85
86
87

logger = init_logger(__name__)

88
89
90
91
92
93
94
95
96
97
98
99
100
101
102

class RendererRequest(Protocol):
    """Structural type for requests that can build tokenization parameters."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Build the `TokenizeParams` for this request from `model_config`.

        The protocol body is only a stub and raises `NotImplementedError`.
        """
        raise NotImplementedError


class RendererChatRequest(RendererRequest, Protocol):
    """A `RendererRequest` that can additionally build chat parameters."""

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        """Build the `ChatParams` for this request from the given defaults.

        The protocol body is only a stub and raises `NotImplementedError`.
        """
        raise NotImplementedError


103
104
# Completion-style request types (including their tokenize, detokenize and
# pooling counterparts).
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | TokenizeCompletionRequest
    | DetokenizeRequest
    | PoolingCompletionRequest
)

# Chat-style request types (including their tokenize and pooling counterparts).
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | BatchChatCompletionRequest
    | TokenizeChatRequest
    | PoolingChatRequest
)

# Transcription and translation requests.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Every request type referenced by the annotations in this module.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Every non-error response type referenced by the annotations in this module.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | GenerateResponse
)

# Type variable for `ServeContext`, bounded by `AnyRequest`.
RequestT = TypeVar("RequestT", bound=AnyRequest)


140
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
    """Per-request state shared by the `OpenAIServing` pipeline stages.

    All fields are keyword-only (`kw_only=True`).
    """

    # The API request being served.
    request: RequestT
    # The underlying FastAPI request, when one is available.
    raw_request: Request | None = None
    model_name: str
    request_id: str
    # Defaults to the current Unix time, truncated to whole seconds.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None
    # Inputs handed to the engine. `_prepare_generators` and `_collect_batch`
    # return an error response while this is still None.
    engine_inputs: list[EngineInput] | None = None

    # Merged stream of `(input index, output)` pairs, set by
    # `_prepare_generators`.
    result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
        None
    )
    # Outputs ordered by input index, filled in by `_collect_batch`.
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # NOTE(review): this name has no annotation, so the stdlib dataclass treats
    # it as a plain class attribute rather than a field — confirm that is the
    # intent.
    model_config = ConfigDict(arbitrary_types_allowed=True)
156
157


158
class OpenAIServing:
159
    request_id_prefix: ClassVar[str] = """
160
    A short string prepended to every request’s ID.
161
    """
162

163
164
    def __init__(
        self,
165
        engine_client: EngineClient,
166
        models: OpenAIServingModels,
167
        *,
168
        request_logger: RequestLogger | None,
169
        return_tokens_as_token_ids: bool = False,
170
    ):
171
172
        super().__init__()

173
        self.engine_client = engine_client
174

175
        self.models = models
176

177
        self.request_logger = request_logger
178
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
179

180
181
182
183
        self.model_config = engine_client.model_config
        self.renderer = engine_client.renderer
        self.io_processor = engine_client.io_processor
        self.input_processor = engine_client.input_processor
184
185
186

    async def beam_search(
        self,
        prompt: EngineInput,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search by issuing one single-token request per live beam.

        Each step asks the engine for `2 * beam_width` logprobs per beam and
        pools every candidate continuation. Candidates whose token is EOS are
        moved to the completed list (unless `ignore_eos`), and the
        `beam_width` best remaining candidates become the next set of beams.

        Args:
            prompt: Engine input to extend. For `"enc_dec"` inputs the decoder
                prompt supplies the prompt text and token ids.
            request_id: Base id; every step and beam derives its own id from it.
            params: Beam search settings (width, max tokens, EOS handling,
                temperature, length penalty, stop-string inclusion).
            lora_request: Optional LoRA adapter attached to the initial beam and
                inherited by its descendants.
            trace_headers: Optional tracing headers forwarded to the engine.

        Yields:
            Exactly one finished `RequestOutput`: either the best beams or, if
            any engine result finished with `"error"`, a single empty output
            whose `finish_reason` is `"error"`.

        Raises:
            NotImplementedError: If `prompt` is an `"embeds"` prompt.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        if prompt["type"] == "embeds":
            raise NotImplementedError("Embedding prompt not supported for beam search")

        # Extract prompt tokens and text based on model type
        decoder_prompt = (
            prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
        )
        prompt_text = decoder_prompt.get("prompt")
        prompt_token_ids = decoder_prompt["prompt_token_ids"]

        tokenized_length = len(prompt_token_ids)

        logprobs_num = 2 * beam_width
        sampling_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, beam in enumerate(all_beams):
                prompt_item = beam.get_prompt()
                lora_request_item = beam.lora_request
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            prompt_item,
                            sampling_params,
                            request_id_item,
                            lora_request=lora_request_item,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Store all new tokens generated by beam
            all_beams_token_id = []
            # Store the cumulative probability of all tokens
            # generated by beam search
            all_beams_logprob = []
            # Index (into `all_beams` / `output`) of the beam that produced each
            # candidate. A result is not guaranteed to contribute exactly
            # `logprobs_num` candidates: its logprobs may be None, and the
            # sampled token is always reported, which allows one extra entry.
            # Deriving the owner from `idx // logprobs_num` would then
            # attribute candidates to the wrong beam.
            all_beams_owner = []
            # Iterate through all beam inference results
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    all_beams_token_id.extend(list(logprobs.keys()))
                    all_beams_logprob.extend(
                        [
                            current_beam.cum_logprob + obj.logprob
                            for obj in logprobs.values()
                        ]
                    )
                    all_beams_owner.extend([i] * len(logprobs))

            # Handle the token for the end of sentence (EOS)
            all_beams_token_id = np.array(all_beams_token_id)
            all_beams_logprob = np.array(all_beams_logprob)

            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
                for idx in eos_idx:
                    beam_idx = all_beams_owner[idx]
                    current_beam = all_beams[beam_idx]
                    result = output[beam_idx]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            orig_prompt=prompt,
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # After processing, set the log probability of the eos condition
                # to negative infinity.
                all_beams_logprob[eos_idx] = -np.inf

            # Processing non-EOS tokens
            # Get indices of the top beam_width probabilities.
            # `np.argpartition` requires kth < number of candidates, so when
            # there are too few candidates to partition, keep all of them.
            num_candidates = len(all_beams_logprob)
            if num_candidates > beam_width:
                topn_idx = np.argpartition(
                    np.negative(all_beams_logprob), beam_width
                )[:beam_width]
            else:
                topn_idx = np.arange(num_candidates)

            for idx in topn_idx:
                beam_idx = all_beams_owner[idx]
                current_beam = all_beams[beam_idx]
                result = output[beam_idx]
                token_id = int(all_beams_token_id[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        orig_prompt=prompt,
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                    )
                )

            all_beams = new_beams
            if not all_beams:
                # Nothing left to extend; further steps would be no-ops.
                break

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
383

384
385
386
    async def _preprocess(
        self,
        ctx: ServeContext,
387
    ) -> ErrorResponse | None:
388
        """
389
        Default preprocessing hook. Subclasses may override to prepare `ctx`.
390
391
392
393
394
395
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
396
    ) -> AnyResponse | ErrorResponse:
397
398
399
400
401
402
403
404
405
        """
        Default response builder. Subclass may override this method
        to return the appropriate response object.
        """
        return self.create_error_response("unimplemented endpoint")

    async def handle(
        self,
        ctx: ServeContext,
406
    ) -> AnyResponse | ErrorResponse:
407
        async for response in self._pipeline(ctx):
408
409
410
411
412
413
414
            return response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
415
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
        """Execute the request processing pipeline yielding responses."""
        if error := await self._check_model(ctx.request):
            yield error
        if error := self._validate_request(ctx):
            yield error

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret

        yield self._build_response(ctx)

436
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
437
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
438

439
440
        if (
            truncate_prompt_tokens is not None
441
            and truncate_prompt_tokens > self.model_config.max_model_len
442
        ):
443
444
445
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
446
                " Please request a smaller truncation size."
447
            )
448
449
        return None

450
451
452
    def _create_pooling_params(
        self,
        ctx: ServeContext,
453
    ) -> PoolingParams | ErrorResponse:
454
455
        if not hasattr(ctx.request, "to_pooling_params"):
            return self.create_error_response(
456
457
                "Request type does not support pooling parameters"
            )
458
459
460

        return ctx.request.to_pooling_params()

461
462
463
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Submit every engine input for encoding and merge the output streams.

        On success `ctx.result_generator` is set to the merged stream and None
        is returned. An `ErrorResponse` is returned instead when pooling
        parameters cannot be built or `ctx.engine_inputs` is None.
        """
        raw_request = ctx.raw_request
        if raw_request is None:
            trace_headers = None
        else:
            trace_headers = await self._get_trace_headers(raw_request.headers)

        pooling_params = self._create_pooling_params(ctx)
        if isinstance(pooling_params, ErrorResponse):
            return pooling_params

        engine_inputs = ctx.engine_inputs
        if engine_inputs is None:
            return self.create_error_response("Engine prompts not available")

        streams: list[AsyncGenerator[PoolingRequestOutput, None]] = []
        for index, engine_input in enumerate(engine_inputs):
            # Each input gets its own id derived from the request id.
            sub_request_id = f"{ctx.request_id}-{index}"

            self._log_inputs(
                sub_request_id,
                engine_input,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            streams.append(
                self.engine_client.encode(
                    engine_input,
                    pooling_params,
                    sub_request_id,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=getattr(ctx.request, "priority", 0),
                )
            )

        ctx.result_generator = merge_async_iterators(*streams)
        return None
505
506
507
508

    async def _collect_batch(
        self,
        ctx: ServeContext,
509
    ) -> ErrorResponse | None:
510
        """Collect batch results from the result generator."""
511
        if ctx.engine_inputs is None:
512
            return self.create_error_response("Engine prompts not available")
513

514
        num_prompts = len(ctx.engine_inputs)
515
516
        final_res_batch: list[PoolingRequestOutput | None]
        final_res_batch = [None] * num_prompts
517

518
519
        if ctx.result_generator is None:
            return self.create_error_response("Result generator not available")
520

521
522
        async for i, res in ctx.result_generator:
            final_res_batch[i] = res
523

524
525
526
527
        if None in final_res_batch:
            return self.create_error_response(
                "Failed to generate results for all prompts"
            )
528

529
        ctx.final_res_batch = [res for res in final_res_batch if res is not None]
530

531
        return None
532

533
    @staticmethod
    def create_error_response(
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an `ErrorResponse` via the module-level helper of the same name.

        Inside the method body the name `create_error_response` resolves to the
        function imported at module level, not to this static method.
        """
        response = create_error_response(message, err_type, status_code, param)
        return response
541

542
    def create_streaming_error_response(
543
        self,
544
        message: str | Exception,
545
546
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
547
        param: str | None = None,
548
    ) -> str:
549
        json_str = json.dumps(
550
            self.create_error_response(
551
552
553
554
                message=message,
                err_type=err_type,
                status_code=status_code,
                param=param,
555
556
            ).model_dump()
        )
557
558
        return json_str

559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Raise GenerationError if finish_reason indicates an error."""
        if finish_reason == "error":
            logger.error(
                "Request %s failed with an internal error during generation",
                request_id,
            )
            raise GenerationError("Internal server error")

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Convert GenerationError to streaming error response."""
        return self.create_streaming_error_response(
            str(e),
            err_type="InternalServerError",
            status_code=e.status_code,
        )

578
    async def _check_model(
579
580
        self,
        request: AnyRequest,
581
    ) -> ErrorResponse | None:
582
583
        error_response = None

584
        if self._is_model_supported(request.model):
585
            return None
586
        if request.model in self.models.lora_requests:
587
            return None
588
589
590
591
592
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
593
594
            if isinstance(load_result, LoRARequest):
                return None
595
596
597
598
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
599
600
601
                error_response = load_result

        return error_response or self.create_error_response(
602
603
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
604
            status_code=HTTPStatus.NOT_FOUND,
605
            param="model",
606
        )
607

608
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

631
    def _maybe_get_adapters(
632
633
634
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
635
    ) -> LoRARequest | None:
636
        if request.model in self.models.lora_requests:
637
            return self.models.lora_requests[request.model]
638
639
640
641
642
643

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
644
                return default_mm_lora
645
646

        if self._is_model_supported(request.model):
647
            return None
648

649
        # if _check_model has been called earlier, this will be unreachable
650
        raise ValueError(f"The model `{request.model}` does not exist.")
651

652
653
654
655
656
657
658
659
660
661
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

662
663
664
665
666
        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
667
668
669
670
671
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
672
673
674
675
676
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

677
678
    def _validate_chat_template(
        self,
679
680
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
681
        trust_request_chat_template: bool,
682
    ) -> ErrorResponse | None:
683
        if not trust_request_chat_template and (
684
685
686
687
688
689
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
690
691
692
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
693
694
                "Refused request with untrusted chat template."
            )
695
696
        return None

697
698
699
700
701
702
703
704
705
706
707
708
    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs."""
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
        return default_chat_template_kwargs | request_chat_template_kwargs

709
    def _extract_prompt_components(self, prompt: PromptType | EngineInput):
        """Delegate to `extract_prompt_components` with this server's model config."""
        components = extract_prompt_components(self.model_config, prompt)
        return components

712
    def _extract_prompt_text(self, prompt: PromptType | EngineInput):
713
714
        return self._extract_prompt_components(prompt).text

715
    def _extract_prompt_len(self, prompt: EngineInput):
        """Delegate to `extract_prompt_len` with this server's model config."""
        prompt_len = extract_prompt_len(self.model_config, prompt)
        return prompt_len

718
719
720
    def _log_inputs(
        self,
        request_id: str,
721
        inputs: PromptType | EngineInput,
722
723
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
724
725
726
    ) -> None:
        if self.request_logger is None:
            return
727

728
        components = self._extract_prompt_components(inputs)
729
730
731

        self.request_logger.log_inputs(
            request_id,
732
733
734
            components.text,
            components.token_ids,
            components.embeds,
735
736
737
            params=params,
            lora_request=lora_request,
        )
738

739
740
741
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return the trace headers to forward to the engine, if any.

        When the engine client reports tracing as enabled, the result of
        `extract_trace_headers(headers)` is returned. Otherwise None is
        returned, after `log_tracing_disabled_warning()` has been called if
        the request nevertheless carried trace headers.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

753
    @staticmethod
754
    def _base_request_id(
755
756
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
757
        """Pulls the request id to use from a header, if provided"""
758
759
760
761
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id
762

763
        return random_uuid() if default is None else default
764

765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided"""
        if raw_request is None:
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None

        try:
            return int(rank_str)
        except ValueError:
            return None

780
781
782
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: type[ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Turn generated `content` into function calls per `request.tool_choice`.

        Modes, checked in this order:

        * a named tool choice (either supported type): the whole content
          becomes the arguments of that one function;
        * `"required"`: the content is validated as a JSON list of function
          definitions; invalid JSON yields no calls;
        * `"auto"` or no choice, with a parser class and auto tools enabled:
          the parser extracts the calls and the remaining content.

        Returns:
            `(function_calls, remaining_content)`. The first element is None
            only when the automatic parser found no tool calls; if no mode
            applies, it is an empty list and the content is returned unchanged.

        Raises:
            ValueError: If automatic parsing is needed but `tokenizer` is None.
            RuntimeError: Re-raised (after logging) if the parser cannot be
                created.
        """
        tool_choice = request.tool_choice

        # Forced function call, Responses API flavour.
        if tool_choice and isinstance(tool_choice, ToolChoiceFunction):
            assert content is not None
            forced_call = FunctionCall(name=tool_choice.name, arguments=content)
            # Content is cleared since it was consumed as the call arguments.
            return [forced_call], None

        # Forced function call, Chat Completions flavour.
        if tool_choice and isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
            assert content is not None
            forced_call = FunctionCall(
                name=tool_choice.function.name, arguments=content
            )
            return [forced_call], None

        if tool_choice == "required":
            try:
                parsed_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
                    content or ""
                )
            except ValidationError:
                # Unparseable content simply produces no tool calls.
                parsed_calls = []
            required_calls = [
                FunctionCall(
                    name=parsed.name,
                    arguments=json.dumps(parsed.parameters, ensure_ascii=False),
                )
                for parsed in parsed_calls
            ]
            return required_calls, None

        auto_parsing = (
            tool_parser_cls
            and enable_auto_tools
            and (tool_choice == "auto" or tool_choice is None)
        )
        if not auto_parsing:
            # No tool handling applies: no calls, content untouched.
            return [], content

        if tokenizer is None:
            raise ValueError(
                "Tokenizer not available when `skip_tokenizer_init=True`"
            )

        # Automatic Tool Call Parsing
        try:
            tool_parser = tool_parser_cls(tokenizer, request.tools)
        except RuntimeError:
            logger.exception("Error in tool parser creation.")
            raise

        extracted = tool_parser.extract_tool_calls(
            content if content is not None else "",
            request=request,  # type: ignore
        )
        if extracted is None or not extracted.tools_called:
            # No tool calls.
            return None, content

        # extract_tool_calls() returns a list of tool calls.
        auto_calls = [
            FunctionCall(
                id=call.id,
                name=call.function.name,
                arguments=call.function.arguments,
            )
            for call in extracted.tool_calls
        ]
        remaining = extracted.content
        if remaining and not remaining.strip():
            # Whitespace-only leftovers are reported as no content.
            remaining = None
        return auto_calls, remaining

859
    @staticmethod
860
861
862
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
863
        tokenizer: TokenizerLike | None,
864
865
        return_as_token_id: bool = False,
    ) -> str:
866
867
868
        if return_as_token_id:
            return f"token_id:{token_id}"

869
870
        if logprob.decoded_token is not None:
            return logprob.decoded_token
871
872
873
874
875
876

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

877
        return tokenizer.decode([token_id])
878

879
    def _is_model_supported(self, model_name: str | None) -> bool:
880
881
        if not model_name:
            return True
882
        return self.models.is_base_model(model_name)
883

884
885

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace `-inf` prompt logprobs with `-9999.0`, in place.

    None entries in the sequence are skipped. The same object that was passed
    in is returned (None stays None).
    """
    if prompt_logprobs is None:
        return None

    negative_infinity = float("-inf")
    entries = (
        entry
        for per_position in prompt_logprobs
        if per_position is not None
        for entry in per_position.values()
    )
    for entry in entries:
        if entry.logprob == negative_infinity:
            entry.logprob = -9999.0
    return prompt_logprobs