serving.py 25.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import contextlib
5
import json
6
import time
7
from collections.abc import AsyncGenerator, Mapping
8
from dataclasses import dataclass, field
9
from http import HTTPStatus
10
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
11

12
import numpy as np
13
from fastapi import Request
14
from openai.types.responses import ToolChoiceFunction
15
from pydantic import ConfigDict, TypeAdapter, ValidationError
16
from starlette.datastructures import Headers
17

18
import vllm.envs as envs
19
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
20
from vllm.config import ModelConfig
21
from vllm.engine.protocol import EngineClient
22
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
23
from vllm.entrypoints.logger import RequestLogger
24
from vllm.entrypoints.openai.chat_completion.protocol import (
25
    BatchChatCompletionRequest,
26
    ChatCompletionNamedToolChoiceParam,
27
28
    ChatCompletionRequest,
    ChatCompletionResponse,
29
)
30
from vllm.entrypoints.openai.completion.protocol import (
31
32
    CompletionRequest,
    CompletionResponse,
33
34
)
from vllm.entrypoints.openai.engine.protocol import (
35
    ErrorResponse,
36
    FunctionCall,
37
    FunctionDefinition,
38
    GenerationError,
39
)
40
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
41
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
42
from vllm.entrypoints.openai.speech_to_text.protocol import (
43
44
45
46
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
47
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
48
49
50
51
52
53
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
54
from vllm.entrypoints.utils import create_error_response
55
from vllm.inputs import EngineInput, PromptType
56
from vllm.logger import init_logger
57
from vllm.logprobs import Logprob, PromptLogprobs
58
from vllm.lora.request import LoRARequest
59
from vllm.outputs import CompletionOutput, RequestOutput
60
from vllm.renderers import ChatParams, TokenizeParams
61
62
63
64
from vllm.renderers.inputs.preprocess import (
    extract_prompt_components,
    extract_prompt_len,
)
65
from vllm.sampling_params import BeamSearchParams, SamplingParams
66
from vllm.tokenizers import TokenizerLike
67
from vllm.tool_parsers import ToolParser
68
69
70
71
72
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
73
from vllm.utils import random_uuid
74
from vllm.utils.async_utils import collect_from_async_generator
75
76
77

logger = init_logger(__name__)

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92

class RendererRequest(Protocol):
    """Structural type for requests that can build their own tokenization params."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Return the ``TokenizeParams`` for this request under ``model_config``."""
        raise NotImplementedError()


class RendererChatRequest(RendererRequest, Protocol):
    """Structural type for chat requests that additionally build chat params."""

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        """Return the ``ChatParams`` for this request.

        Args:
            default_template: Default chat template to fall back on, if any.
            default_template_content_format: Default content format for the
                chat template.
        """
        raise NotImplementedError()


93
# Completion-style request types (including tokenize/detokenize requests).
CompletionLikeRequest: TypeAlias = (
    CompletionRequest | TokenizeCompletionRequest | DetokenizeRequest
)

# Chat-style request types (including batch chat and chat tokenization).
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest | BatchChatCompletionRequest | TokenizeChatRequest
)

# Audio transcription / translation request types.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Union of every request type; used as the parameter type of the shared
# request helpers on ``OpenAIServing`` (e.g. ``_check_model``).
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | GenerateRequest
)

# Union of the response types defined by the imported protocol modules.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | TranscriptionResponse
    | TokenizeResponse
    | GenerateResponse
)

# Concrete request type carried by a ``ServeContext``; bounded by AnyRequest.
RequestT = TypeVar("RequestT", bound=AnyRequest)


122
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
    """Per-request state carried through a serving call.

    All fields are keyword-only (``kw_only=True``), which is what allows
    fields with defaults (``raw_request``) to precede required fields
    (``model_name``, ``request_id``).
    """

    # The API request object being served (one of the ``AnyRequest`` types).
    request: RequestT
    # The underlying FastAPI request, when one is available.
    raw_request: Request | None = None

    model_name: str
    request_id: str
    # Whole seconds since the epoch, captured when the context is constructed.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None
    engine_inputs: list[EngineInput] | None = None

    # NOTE(review): unannotated, so this is a plain class attribute rather than
    # a dataclass field. ``ConfigDict`` is pydantic configuration and is not
    # consumed by ``dataclasses`` itself — confirm whether anything reads it.
    model_config = ConfigDict(arbitrary_types_allowed=True)
132
133


134
class OpenAIServing:
    """Common functionality for the OpenAI-compatible serving classes.

    Wraps an ``EngineClient`` and an ``OpenAIServingModels`` registry and
    provides helpers shared across endpoints: beam search built on
    ``engine_client.generate``, error-response construction, model / LoRA
    resolution, chat-template validation, prompt inspection and logging,
    trace-header handling, request-id / data-parallel-rank extraction, and
    tool-call parsing.
    """

    # NOTE(review): because this is an assignment, the triple-quoted text is
    # the *value* of ``request_id_prefix`` (newlines and indentation
    # included), not documentation for it. The value is kept unchanged for
    # compatibility; presumably subclasses override it with a short prefix —
    # confirm before relying on the base-class value.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
    ):
        """
        Args:
            engine_client: Engine used for generation. Its ``model_config``,
                ``renderer`` and ``input_processor`` are cached on ``self``.
            models: Registry of served base models and LoRA adapters.
            request_logger: Logger for request inputs; ``None`` disables
                input logging in ``_log_inputs``.
            return_tokens_as_token_ids: Stored for use by subclasses;
                presumably selects ``token_id:<id>`` token rendering (see
                ``_get_decoded_token``) — confirm at the call sites.
        """
        super().__init__()

        self.engine_client = engine_client
        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        self.model_config = engine_client.model_config
        self.renderer = engine_client.renderer
        self.input_processor = engine_client.input_processor

    async def beam_search(
        self,
        prompt: EngineInput,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search by issuing one single-token request per beam per step.

        For up to ``params.max_tokens`` steps, every live beam is extended by
        one token through ``engine_client.generate`` with ``2 * beam_width``
        logprobs requested. Unless ``ignore_eos`` is set, candidates whose
        token is EOS move to the completed set; up to ``beam_width`` of the
        best remaining candidates become the next live beams.

        Yields:
            Exactly one finished ``RequestOutput``: either the best
            ``beam_width`` sequences (by ``sort_beams_key``), or a single
            empty output with ``finish_reason="error"`` as soon as any engine
            step reports an error.

        Raises:
            NotImplementedError: If ``prompt`` is an embeddings prompt.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        if prompt["type"] == "embeds":
            raise NotImplementedError("Embedding prompt not supported for beam search")

        # Extract prompt tokens and text based on model type
        decoder_prompt = (
            prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
        )
        prompt_text = decoder_prompt.get("prompt")
        prompt_token_ids = decoder_prompt["prompt_token_ids"]

        tokenized_length = len(prompt_token_ids)

        # Ask for twice the beam width so that enough non-EOS continuations
        # remain after EOS candidates are set aside.
        logprobs_num = 2 * beam_width
        sampling_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            if not all_beams:
                # Nothing left to extend: the previous step produced no
                # candidates at all.
                break

            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, beam in enumerate(all_beams):
                prompt_item = beam.get_prompt()
                lora_request_item = beam.lora_request
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            prompt_item,
                            sampling_params,
                            request_id_item,
                            lora_request=lora_request_item,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            # Keep the first collected RequestOutput of each beam's request.
            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Flattened candidates, one entry per (beam, token) pair: the
            # token id, the cumulative logprob of the extended beam, and the
            # index of the beam being extended. The beam index is recorded
            # explicitly rather than derived as ``idx // logprobs_num``: a
            # beam whose result carries no logprobs contributes no
            # candidates (and a beam could return a count other than
            # ``logprobs_num``), which would shift that derived mapping onto
            # the wrong beam.
            all_beams_token_id = []
            all_beams_logprob = []
            all_beams_index = []
            # Iterate through all beam inference results
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    all_beams_token_id.extend(logprobs.keys())
                    all_beams_logprob.extend(
                        current_beam.cum_logprob + obj.logprob
                        for obj in logprobs.values()
                    )
                    all_beams_index.extend([i] * len(logprobs))

            # Handle the token for the end of sentence (EOS)
            all_beams_token_id = np.array(all_beams_token_id)
            all_beams_logprob = np.array(all_beams_logprob)

            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
                for idx in eos_idx:
                    beam_idx = all_beams_index[idx]
                    current_beam = all_beams[beam_idx]
                    result = output[beam_idx]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            orig_prompt=prompt,
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # After processing, set the log probability of the eos condition
                # to negative infinity so EOS candidates are not kept as live
                # beams.
                all_beams_logprob[eos_idx] = -np.inf

            # Processing non-EOS tokens: pick the top ``beam_width`` candidates.
            # ``np.argpartition`` requires ``kth < len(array)``, so when there
            # are no more than ``beam_width`` candidates, all of them are kept.
            num_candidates = len(all_beams_logprob)
            if num_candidates > beam_width:
                topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[
                    :beam_width
                ]
            else:
                topn_idx = np.arange(num_candidates)

            for idx in topn_idx:
                beam_idx = all_beams_index[idx]
                current_beam = all_beams[beam_idx]
                result = output[beam_idx]
                token_id = int(all_beams_token_id[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        orig_prompt=prompt,
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                    )
                )

            all_beams = new_beams

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )

    @staticmethod
    def create_error_response(
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an ``ErrorResponse`` (delegates to the module-level helper)."""
        return create_error_response(message, err_type, status_code, param)

    def create_streaming_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> str:
        """Return the JSON-serialized ``ErrorResponse`` for a streaming reply."""
        json_str = json.dumps(
            self.create_error_response(
                message=message,
                err_type=err_type,
                status_code=status_code,
                param=param,
            ).model_dump()
        )
        return json_str

    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Raise GenerationError if finish_reason indicates an error.

        The failure is logged with ``request_id``; the raised error carries
        only a generic "Internal server error" message.
        """
        if finish_reason == "error":
            logger.error(
                "Request %s failed with an internal error during generation",
                request_id,
            )
            raise GenerationError("Internal server error")

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Convert GenerationError to a JSON streaming error response.

        Uses ``err_type="InternalServerError"`` and the error's own
        ``status_code``.
        """
        return self.create_streaming_error_response(
            str(e),
            err_type="InternalServerError",
            status_code=e.status_code,
        )

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Return ``None`` if ``request.model`` can be served, else an error.

        The model is accepted when it is empty/unset or a base model, when it
        names an already-registered LoRA adapter, or, with
        ``VLLM_ALLOW_RUNTIME_LORA_UPDATING`` enabled, when
        ``models.resolve_lora`` returns a ``LoRARequest`` for it. A
        ``BAD_REQUEST`` ``ErrorResponse`` from resolution is returned as-is;
        every other failure becomes a 404 ``NotFoundError`` on ``model``.
        """
        error_response = None

        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
            if isinstance(load_result, LoRARequest):
                return None
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
                error_response = load_result

        return error_response or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )

    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
        """Return the single default multimodal LoRA matching the request, if any.

        A registered LoRA matches when its ``lora_name`` equals one of the
        content-type prefixes from ``_get_message_types``. ``None`` is
        returned unless exactly one LoRA matches.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> LoRARequest | None:
        """Return the LoRA adapter to apply to ``request``, or ``None``.

        Lookup order: a LoRA registered under ``request.model``; then, if
        ``supports_default_mm_loras``, a single matching default multimodal
        LoRA; then ``None`` for a supported base model.

        Raises:
            ValueError: If ``request.model`` is neither a LoRA nor a supported
                model.
        """
        if request.model in self.models.lora_requests:
            return self.models.lora_requests[request.model]

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                return default_mm_lora

        if self._is_model_supported(request.model):
            return None

        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Collect the content-part type prefixes used in the request's messages.

        For every dict message whose ``content`` is a list, each content part
        that is a mapping with a string ``"type"`` contributes the portion of
        that type before the first ``_`` (e.g. ``audio_url`` -> ``audio``).
        These prefixes are matched against LoRA names to find default
        per-modality adapters.

        Requests without a ``messages`` attribute, or whose ``messages`` is
        ``None`` or a plain str/bytes, yield an empty set. Content parts that
        are not mappings, or whose ``"type"`` is not a string, are skipped
        instead of raising.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
                for content_dict in message["content"]:
                    # A non-mapping part (e.g. a bare string) would otherwise
                    # turn ``"type" in part`` into a substring test and then
                    # fail on ``part["type"]``.
                    if not isinstance(content_dict, Mapping):
                        continue
                    content_type = content_dict.get("type")
                    if isinstance(content_type, str):
                        message_types.add(content_type.split("_")[0])
        return message_types

    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Reject request-supplied chat templates unless they are trusted.

        A template counts as supplied when ``request_chat_template`` is set or
        ``chat_template_kwargs["chat_template"]`` is not ``None``. Returns an
        ``ErrorResponse`` in that case when ``trust_request_chat_template`` is
        false, otherwise ``None``.
        """
        if not trust_request_chat_template and (
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
                "Refused request with untrusted chat template."
            )
        return None

    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs.

        Request kwargs take precedence over server defaults. When there are
        no defaults, the request dict itself (or a new empty dict) is
        returned.
        """
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
        return default_chat_template_kwargs | request_chat_template_kwargs

    def _extract_prompt_components(self, prompt: PromptType | EngineInput):
        """Return the prompt components (text / token ids / embeds) of ``prompt``."""
        return extract_prompt_components(self.model_config, prompt)

    def _extract_prompt_text(self, prompt: PromptType | EngineInput):
        """Return the text component of ``prompt``."""
        return self._extract_prompt_components(prompt).text

    def _extract_prompt_len(self, prompt: EngineInput):
        """Return the prompt length of an engine input."""
        return extract_prompt_len(self.model_config, prompt)

    def _log_inputs(
        self,
        request_id: str,
        inputs: PromptType | EngineInput,
        params: SamplingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        """Log the request's prompt components; no-op without a request logger."""
        if self.request_logger is None:
            return

        components = self._extract_prompt_components(inputs)

        self.request_logger.log_inputs(
            request_id,
            components.text,
            components.token_ids,
            components.embeds,
            params=params,
            lora_request=lora_request,
        )

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return trace headers to forward, or ``None`` when tracing is off.

        If tracing is disabled but the request carries trace headers, a
        warning is logged.
        """
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Pulls the request id to use from a header, if provided.

        Returns the ``X-Request-Id`` header when present; otherwise
        ``default``, or a fresh ``random_uuid()`` when ``default`` is None.
        """
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id

        return random_uuid() if default is None else default

    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided.

        Returns ``None`` when there is no request, no ``X-data-parallel-rank``
        header, or the header is not a valid integer.
        """
        if raw_request is None:
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None

        try:
            return int(rank_str)
        except ValueError:
            return None

    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: type[ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Extract function calls from generated ``content`` per ``tool_choice``.

        - Named tool choice (``ToolChoiceFunction`` or
          ``ChatCompletionNamedToolChoiceParam``): all of ``content`` becomes
          the arguments of one call to that function; content is cleared.
        - ``"required"``: ``content`` is parsed as a JSON list of
          ``FunctionDefinition``; content that fails validation yields no
          calls. Content is cleared either way.
        - ``"auto"`` / ``None`` with auto tools enabled and a parser class:
          ``tool_parser_cls(...).extract_tool_calls`` decides. If it reports
          no tool calls, ``(None, content)`` is returned.
        - Any other case: an empty list and the unchanged content.

        Returns:
            ``(function_calls, remaining_content)``.

        Raises:
            ValueError: If auto parsing is needed but ``tokenizer`` is None.
            RuntimeError: Re-raised (after logging) from parser construction.
        """
        function_calls = list[FunctionCall]()
        if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=request.tool_choice.name, arguments=content)
            )
            content = None  # Clear content since tool is called.
        elif request.tool_choice and isinstance(
            request.tool_choice, ChatCompletionNamedToolChoiceParam
        ):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=request.tool_choice.function.name, arguments=content)
            )
            content = None  # Clear content since tool is called.
        elif request.tool_choice == "required":
            tool_calls = []
            # Invalid output leaves ``tool_calls`` empty rather than failing.
            with contextlib.suppress(ValidationError):
                content = content or ""
                tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
                    content
                )
            for tool_call in tool_calls:
                function_calls.append(
                    FunctionCall(
                        name=tool_call.name,
                        arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
                    )
                )
            content = None  # Clear content since tool is called.
        elif (
            tool_parser_cls
            and enable_auto_tools
            and (request.tool_choice == "auto" or request.tool_choice is None)
        ):
            if tokenizer is None:
                raise ValueError(
                    "Tokenizer not available when `skip_tokenizer_init=True`"
                )

            # Automatic Tool Call Parsing
            try:
                tool_parser = tool_parser_cls(tokenizer, request.tools)
            except RuntimeError as e:
                logger.exception("Error in tool parser creation.")
                raise e
            tool_call_info = tool_parser.extract_tool_calls(
                content if content is not None else "",
                request=request,  # type: ignore
            )
            if tool_call_info is not None and tool_call_info.tools_called:
                # extract_tool_calls() returns a list of tool calls.
                function_calls.extend(
                    FunctionCall(
                        id=tool_call.id,
                        name=tool_call.function.name,
                        arguments=tool_call.function.arguments,
                    )
                    for tool_call in tool_call_info.tool_calls
                )
                content = tool_call_info.content
                # Whitespace-only leftover content is treated as no content.
                if content and content.strip() == "":
                    content = None
            else:
                # No tool calls.
                return None, content

        return function_calls, content

    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: TokenizerLike | None,
        return_as_token_id: bool = False,
    ) -> str:
        """Return the display string for ``token_id``.

        ``"token_id:<id>"`` when ``return_as_token_id`` is set; otherwise the
        already-decoded token on ``logprob`` if present; otherwise the
        tokenizer's decoding of the single token.

        Raises:
            ValueError: If decoding is needed but ``tokenizer`` is None.
        """
        if return_as_token_id:
            return f"token_id:{token_id}"

        if logprob.decoded_token is not None:
            return logprob.decoded_token

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

        return tokenizer.decode([token_id])

    def _is_model_supported(self, model_name: str | None) -> bool:
        """True for an empty/unset model name or a registered base model."""
        if not model_name:
            return True
        return self.models.is_base_model(model_name)
708

709
710

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    Positions whose entry is ``None`` are skipped. The same (mutated) object
    is returned, or ``None`` when ``prompt_logprobs`` is ``None``.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    for logprob_dict in prompt_logprobs:
        if logprob_dict is None:
            continue
        for logprob_values in logprob_dict.values():
            if logprob_values.logprob == float("-inf"):
                # NOTE(review): presumably a finite sentinel so the value
                # survives JSON serialization — confirm against the callers.
                logprob_values.logprob = -9999.0
    return prompt_logprobs