serving.py 27.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import contextlib
5
import json
6
import time
7
from collections.abc import AsyncGenerator, Mapping
8
from dataclasses import dataclass, field
9
from http import HTTPStatus
10
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
11

12
import numpy as np
13
from fastapi import Request
14
from openai.types.responses import ToolChoiceFunction
15
from pydantic import ConfigDict, TypeAdapter, ValidationError
16
from starlette.datastructures import Headers
17

18
import vllm.envs as envs
19
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
20
from vllm.config import ModelConfig
21
from vllm.engine.protocol import EngineClient
22
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
23
from vllm.entrypoints.logger import RequestLogger
24
from vllm.entrypoints.openai.chat_completion.protocol import (
25
    BatchChatCompletionRequest,
26
    ChatCompletionNamedToolChoiceParam,
27
28
    ChatCompletionRequest,
    ChatCompletionResponse,
29
)
30
from vllm.entrypoints.openai.completion.protocol import (
31
32
    CompletionRequest,
    CompletionResponse,
33
34
)
from vllm.entrypoints.openai.engine.protocol import (
35
    ErrorResponse,
36
    FunctionCall,
37
    FunctionDefinition,
38
    GenerationError,
39
)
40
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
41
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
42
from vllm.entrypoints.openai.speech_to_text.protocol import (
43
44
45
46
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
47
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
48
49
50
51
52
53
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
54
from vllm.entrypoints.utils import create_error_response
55
from vllm.inputs import EngineInput, PromptType
56
from vllm.logger import init_logger
57
from vllm.logprobs import Logprob, PromptLogprobs
58
from vllm.lora.request import LoRARequest
59
from vllm.outputs import CompletionOutput, RequestOutput
60
from vllm.renderers import ChatParams, TokenizeParams
61
62
63
64
from vllm.renderers.inputs.preprocess import (
    extract_prompt_components,
    extract_prompt_len,
)
65
from vllm.sampling_params import BeamSearchParams, SamplingParams
66
from vllm.tokenizers import TokenizerLike
67
from vllm.tool_parsers import ToolParser
68
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
69
70
71
72
73
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
74
from vllm.utils import random_uuid
75
from vllm.utils.async_utils import collect_from_async_generator
76
77
78

logger = init_logger(__name__)

79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

class RendererRequest(Protocol):
    """Structural interface for requests that can build tokenization params."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Build the ``TokenizeParams`` for this request from ``model_config``.

        The protocol's own implementation is only a stub and always raises.
        """
        raise NotImplementedError


class RendererChatRequest(RendererRequest, Protocol):
    """A ``RendererRequest`` that can additionally build chat-rendering params."""

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        """Build the ``ChatParams`` for this request.

        ``default_template`` and ``default_template_content_format`` are the
        fallbacks supplied by the caller. The protocol's own implementation
        is only a stub and always raises.
        """
        raise NotImplementedError


94
# Completion-style request models (prompt / token based endpoints).
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | TokenizeCompletionRequest
    | DetokenizeRequest
)

# Chat-style request models (message based endpoints).
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | BatchChatCompletionRequest
    | TokenizeChatRequest
)

# Audio transcription / translation request models.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Every request model accepted by the shared serving helpers.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | GenerateRequest
)

# Every (non-error) response model the serving layer can produce.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | TranscriptionResponse
    | TokenizeResponse
    | GenerateResponse
)

# Type variable for generic helpers parameterised over a request model.
RequestT = TypeVar("RequestT", bound=AnyRequest)


123
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
    """Bundle of per-request state: the request, its ids, LoRA and inputs.

    Every field is keyword-only (``kw_only=True``). ``request``,
    ``model_name`` and ``request_id`` have no default and must be supplied.
    """

    # The parsed API request model.
    request: RequestT
    # The raw FastAPI request, when one is available.
    raw_request: Request | None = None

    model_name: str
    request_id: str
    # Wall-clock seconds since the epoch (truncated to int) at construction.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None
    # Engine-ready prompts; None by default (presumably filled in later).
    engine_inputs: list[EngineInput] | None = None

    # NOTE(review): unannotated, so @dataclass treats this as a plain class
    # attribute, not a field; looks like leftover pydantic config — confirm
    # whether anything reads it before removing.
    model_config = ConfigDict(arbitrary_types_allowed=True)
133
134


135
class OpenAIServing:
    """Shared helpers for the OpenAI-compatible serving handlers.

    Holds the engine client, the served-model registry and an optional
    request logger. Provides beam search on top of
    ``engine_client.generate``, error-response construction, model/LoRA
    resolution, chat-template validation, prompt logging, trace-header
    handling, request-id / data-parallel-rank extraction and tool-call parsing.
    """

    # A short string prepended to every request's ID.
    # NOTE(review): this used to hold its own description text (a multi-line
    # string) as the value, which would be embedded in any ID built from it
    # by a subclass that does not override it; "" is the neutral default.
    request_id_prefix: ClassVar[str] = ""

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__()

        self.engine_client = engine_client
        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        # Cache engine components used throughout the helpers below.
        self.model_config = engine_client.model_config
        self.renderer = engine_client.renderer
        self.input_processor = engine_client.input_processor

    async def beam_search(
        self,
        prompt: EngineInput,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search by repeatedly generating one token per beam.

        Each step submits every live beam to the engine as its own
        ``max_tokens=1`` request asking for ``2 * beam_width`` logprobs.
        All returned candidates are pooled. Unless ``ignore_eos`` is set,
        EOS candidates are moved into the completed set. The
        ``beam_width`` best remaining candidates become the next beams.

        Yields exactly one finished ``RequestOutput``. It holds the best
        ``beam_width`` beams, or a single empty output with
        ``finish_reason="error"`` if any per-beam request failed.

        Raises:
            NotImplementedError: if ``prompt`` is an embeddings prompt.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        if prompt["type"] == "embeds":
            raise NotImplementedError("Embedding prompt not supported for beam search")

        # Extract prompt tokens and text based on model type
        decoder_prompt = (
            prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
        )
        prompt_text = decoder_prompt.get("prompt")
        prompt_token_ids = decoder_prompt["prompt_token_ids"]

        tokenized_length = len(prompt_token_ids)

        logprobs_num = 2 * beam_width
        sampling_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            if not all_beams:
                # No candidate survived the previous step; nothing to extend.
                break

            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, beam in enumerate(all_beams):
                prompt_item = beam.get_prompt()
                lora_request_item = beam.lora_request
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            prompt_item,
                            sampling_params,
                            request_id_item,
                            lora_request=lora_request_item,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Store all new tokens generated by beam
            all_beams_token_id = []
            # Store the cumulative probability of all tokens
            # generated by beam search
            all_beams_logprob = []
            # Index (into all_beams / output) of the beam each candidate came
            # from. Tracked explicitly because nothing here guarantees each
            # logprobs dict holds exactly `logprobs_num` entries (or that every
            # beam returned logprobs), so `idx // logprobs_num` could map a
            # candidate to the wrong beam.
            candidate_beam_idx: list[int] = []
            # Iterate through all beam inference results
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    all_beams_token_id.extend(list(logprobs.keys()))
                    all_beams_logprob.extend(
                        [
                            current_beam.cum_logprob + obj.logprob
                            for obj in logprobs.values()
                        ]
                    )
                    candidate_beam_idx.extend([i] * len(logprobs))

            token_id_arr = np.array(all_beams_token_id)
            logprob_arr = np.array(all_beams_logprob)

            # Handle the token for the end of sentence (EOS)
            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(token_id_arr == eos_token_id)[0]
                for idx in eos_idx:
                    beam_idx = candidate_beam_idx[idx]
                    current_beam = all_beams[beam_idx]
                    result = output[beam_idx]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            orig_prompt=prompt,
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(logprob_arr[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # After processing, set the log probability of the eos condition
                # to negative infinity so EOS candidates are not kept as beams
                # while better candidates exist.
                logprob_arr[eos_idx] = -np.inf

            # Processing non-EOS tokens
            # Get indices of the top beam_width probabilities. argpartition
            # requires kth < len, so keep every candidate when there are too few.
            num_candidates = len(logprob_arr)
            if num_candidates > beam_width:
                topn_idx = np.argpartition(np.negative(logprob_arr), beam_width)[
                    :beam_width
                ]
            else:
                topn_idx = np.arange(num_candidates)

            for idx in topn_idx:
                beam_idx = candidate_beam_idx[idx]
                current_beam = all_beams[beam_idx]
                result = output[beam_idx]
                token_id = int(token_id_arr[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        orig_prompt=prompt,
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(logprob_arr[idx]),
                    )
                )

            all_beams = new_beams

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )

    @staticmethod
    def create_error_response(
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an ``ErrorResponse`` via the shared entrypoint helper."""
        return create_error_response(message, err_type, status_code, param)

    def create_streaming_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> str:
        """Return the JSON-serialized ``ErrorResponse`` for streaming output."""
        json_str = json.dumps(
            self.create_error_response(
                message=message,
                err_type=err_type,
                status_code=status_code,
                param=param,
            ).model_dump()
        )
        return json_str

    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Raise GenerationError if finish_reason indicates an error."""
        if finish_reason == "error":
            logger.error(
                "Request %s failed with an internal error during generation",
                request_id,
            )
            raise GenerationError("Internal server error")

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Convert GenerationError to streaming error response."""
        return self.create_streaming_error_response(
            str(e),
            err_type="InternalServerError",
            status_code=e.status_code,
        )

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Return None if ``request.model`` can be served, else an error.

        A model is accepted if it is a base model (or empty), an already
        loaded LoRA, or a LoRA that can be resolved at runtime when
        ``VLLM_ALLOW_RUNTIME_LORA_UPDATING`` is enabled. A 400 error from
        LoRA resolution is passed through; otherwise a 404 is returned.
        """
        error_response = None

        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
            if isinstance(load_result, LoRARequest):
                return None
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
                error_response = load_result

        return error_response or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )

    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> LoRARequest | None:
        """Return the LoRA to apply for ``request``, or None for the base model.

        Raises:
            ValueError: if ``request.model`` is neither a known LoRA nor a
                supported base model (unreachable after ``_check_model``).
        """
        if request.model in self.models.lora_requests:
            return self.models.lora_requests[request.model]

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                return default_mm_lora

        if self._is_model_supported(request.model):
            return None

        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.

        Only dict messages whose ``content`` is a list are inspected, and
        only mapping parts with a string ``type`` contribute; anything else
        is skipped instead of raising.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
                for content_part in message["content"]:
                    # A non-mapping part (e.g. a bare string) would make the
                    # `"type" in part` test a substring check and the
                    # subsequent ["type"] lookup fail.
                    if not isinstance(content_part, Mapping):
                        continue
                    content_type = content_part.get("type")
                    if isinstance(content_type, str):
                        message_types.add(content_type.split("_")[0])
        return message_types

    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Reject request-supplied chat templates unless they are trusted.

        Returns an ``ErrorResponse`` when a template is passed (directly or
        via ``chat_template_kwargs["chat_template"]``) and
        ``trust_request_chat_template`` is False; otherwise None.
        """
        if not trust_request_chat_template and (
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
                "Refused request with untrusted chat template."
            )
        return None

    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs."""
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
        return default_chat_template_kwargs | request_chat_template_kwargs

    def _extract_prompt_components(self, prompt: PromptType | EngineInput):
        """Split ``prompt`` into text / token ids / embeds for this model."""
        return extract_prompt_components(self.model_config, prompt)

    def _extract_prompt_text(self, prompt: PromptType | EngineInput):
        """Return only the text component of ``prompt``."""
        return self._extract_prompt_components(prompt).text

    def _extract_prompt_len(self, prompt: EngineInput):
        """Return the prompt length as computed by ``extract_prompt_len``."""
        return extract_prompt_len(self.model_config, prompt)

    def _log_inputs(
        self,
        request_id: str,
        inputs: PromptType | EngineInput,
        params: SamplingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        """Forward the prompt components to the request logger, if any."""
        if self.request_logger is None:
            return

        components = self._extract_prompt_components(inputs)

        self.request_logger.log_inputs(
            request_id,
            components.text,
            components.token_ids,
            components.embeds,
            params=params,
            lora_request=lora_request,
        )

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return trace headers if tracing is enabled, else None.

        When tracing is disabled but the request carries trace headers, a
        warning is logged via ``log_tracing_disabled_warning``.
        """
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Pulls the request id to use from a header, if provided.

        Falls back to ``default`` and then to a fresh random UUID.
        """
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id

        return random_uuid() if default is None else default

    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided.

        Returns None when the header is absent or not a valid integer.
        """
        if raw_request is None:
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None

        try:
            return int(rank_str)
        except ValueError:
            return None

    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: type[ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Extract function calls from generated ``content``.

        Returns ``(function_calls, content)``. ``function_calls`` is None
        only when the tool parser ran and reported no tool calls; otherwise
        it is a (possibly empty) list. In the forced / named / "required"
        branches the whole output is treated as the tool call and
        ``content`` is cleared to None.

        Raises:
            ValueError: if a tool parser is needed but ``tokenizer`` is None.
            RuntimeError: re-raised from tool parser construction.
        """
        # When the Mistral grammar factory injected structured outputs,
        # let the parser handle the output.
        use_mistral_tool_parser = (
            isinstance(request, ChatCompletionRequest)
            and tool_parser_cls is not None
            and issubclass(tool_parser_cls, MistralToolParser)
            and request._grammar_from_tool_parser
        )

        function_calls = list[FunctionCall]()
        if (
            not use_mistral_tool_parser
            and request.tool_choice
            and isinstance(request.tool_choice, ToolChoiceFunction)
        ):
            assert content is not None
            # Forced Function Call (Responses API)
            function_calls.append(
                FunctionCall(name=request.tool_choice.name, arguments=content)
            )
            content = None  # Clear content since tool is called.
        elif (
            not use_mistral_tool_parser
            and request.tool_choice
            and isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
            and (tool_parser_cls is None or tool_parser_cls.supports_required_and_named)
        ):
            # Named function with standard JSON-based parsing
            assert content is not None
            function_calls.append(
                FunctionCall(name=request.tool_choice.function.name, arguments=content)
            )
            content = None  # Clear content since tool is called.
        elif (
            not use_mistral_tool_parser
            and request.tool_choice == "required"
            and (tool_parser_cls is None or tool_parser_cls.supports_required_and_named)
        ):
            # "required" with standard JSON-based parsing
            tool_calls = []
            with contextlib.suppress(ValidationError):
                content = content or ""
                tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
                    content
                )
            for tool_call in tool_calls:
                function_calls.append(
                    FunctionCall(
                        name=tool_call.name,
                        arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
                    )
                )
            content = None  # Clear content since tool is called.
        elif tool_parser_cls and (
            use_mistral_tool_parser
            or (
                enable_auto_tools
                and (
                    request.tool_choice == "auto"
                    or request.tool_choice is None
                    or (
                        not tool_parser_cls.supports_required_and_named
                        and request.tools
                        and (
                            request.tool_choice == "required"
                            or isinstance(
                                request.tool_choice,
                                ChatCompletionNamedToolChoiceParam,
                            )
                        )
                    )
                )
            )
        ):
            # Automatic Tool Call Parsing (also used as fallback for
            # required/named when supports_required_and_named=False)
            if tokenizer is None:
                raise ValueError(
                    "Tokenizer not available when `skip_tokenizer_init=True`"
                )

            try:
                tool_parser = tool_parser_cls(tokenizer, request.tools)
            except RuntimeError as e:
                logger.exception("Error in tool parser creation.")
                raise e
            tool_call_info = tool_parser.extract_tool_calls(
                content if content is not None else "",
                request=request,  # type: ignore
            )
            if tool_call_info is not None and tool_call_info.tools_called:
                # extract_tool_calls() returns a list of tool calls.
                function_calls.extend(
                    FunctionCall(
                        id=tool_call.id,
                        name=tool_call.function.name,
                        arguments=tool_call.function.arguments,
                    )
                    for tool_call in tool_call_info.tool_calls
                )
                content = tool_call_info.content
                if content and content.strip() == "":
                    content = None
            else:
                # No tool calls.
                return None, content

        return function_calls, content

    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: TokenizerLike | None,
        return_as_token_id: bool = False,
    ) -> str:
        """Return the display string for ``token_id``.

        Preference order: ``"token_id:<id>"`` when ``return_as_token_id``,
        then the already-decoded ``logprob.decoded_token``, then
        ``tokenizer.decode``.

        Raises:
            ValueError: if decoding is needed but ``tokenizer`` is None.
        """
        if return_as_token_id:
            return f"token_id:{token_id}"

        if logprob.decoded_token is not None:
            return logprob.decoded_token

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

        return tokenizer.decode([token_id])

    def _is_model_supported(self, model_name: str | None) -> bool:
        """True for an empty/None model name or a registered base model."""
        if not model_name:
            return True
        return self.models.is_base_model(model_name)
747

748
749

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprob values with ``-9999.0``, in place.

    ``None`` positions are skipped. The very object passed in is returned,
    so ``None`` input yields ``None``.

    NOTE(review): the finite sentinel presumably keeps values
    JSON-serializable; confirm before changing it.
    """
    if prompt_logprobs is not None:
        negative_infinity = float("-inf")
        # Flatten every Logprob object across all non-None positions.
        every_entry = (
            entry
            for position in prompt_logprobs
            if position is not None
            for entry in position.values()
        )
        for entry in every_entry:
            if entry.logprob == negative_infinity:
                entry.logprob = -9999.0
    return prompt_logprobs