"tests/vscode:/vscode.git/clone" did not exist on "d93bf4da855a0c5e8d3c875def6b37c5e9d77763"
protocol.py 93 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
6
import json
Zhuohan Li's avatar
Zhuohan Li committed
7
import time
8
from http import HTTPStatus
9
10
from typing import (Annotated, Any, ClassVar, Generic, Literal, Optional,
                    TypeVar, Union)
Zhuohan Li's avatar
Zhuohan Li committed
11

12
import regex as re
13
import torch
14
from fastapi import HTTPException, UploadFile
15
16
17
18
19
20
# yapf: disable
from openai.types.chat.chat_completion_audio import (
    ChatCompletionAudio as OpenAIChatCompletionAudio)
from openai.types.chat.chat_completion_message import (
    Annotation as OpenAIAnnotation)
# yapf: enable
21
22
from openai.types.responses import (ResponseFunctionToolCall,
                                    ResponseInputItemParam, ResponseOutputItem,
23
                                    ResponsePrompt, ResponseReasoningItem,
24
25
26
27
28
29
30
31
32
                                    ResponseStatus)

# Backward compatibility for OpenAI client versions
try:  # For older openai versions (< 1.100.0)
    from openai.types.responses import ResponseTextConfig
except ImportError:  # For newer openai versions (>= 1.100.0)
    from openai.types.responses import (ResponseFormatTextConfig as
                                        ResponseTextConfig)

33
34
35
from openai.types.responses.response import ToolChoice
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
36
37
from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
                      ValidationInfo, field_validator, model_validator)
38
from typing_extensions import TypeAlias
Zhuohan Li's avatar
Zhuohan Li committed
39

40
from vllm import envs
41
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
42
                                         make_tool_call_id)
43
44
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam)
45
from vllm.logger import init_logger
46
from vllm.logprobs import Logprob
47
from vllm.pooling_params import PoolingParams
48
49
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
50
from vllm.utils import random_uuid, resolve_obj_by_qualname
51

52
53
logger = init_logger(__name__)

54
_LONG_INFO = torch.iinfo(torch.long)
55

Zhuohan Li's avatar
Zhuohan Li committed
56

57
class OpenAIBaseModel(BaseModel):
    # Common base for the protocol models in this file: unknown request
    # fields are accepted, and a warning is logged when a dict input
    # carries keys that match no declared field name or alias.

    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Cache class field names
    field_names: ClassVar[Optional[set[str]]] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Run normal validation, then warn about unrecognised input keys.

        Args:
            data: Raw input handed to the model; only dict input is
                inspected for extra keys.
            handler: The wrapped validation callable.

        Returns:
            Whatever ``handler(data)`` returns.
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Read the cache from this class's own namespace. A plain
        # ``cls.field_names`` lookup would also find a set cached on a
        # parent model, which lacks the fields this subclass adds and would
        # make them look like ignored extras.
        field_names = vars(cls).get("field_names")
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        extra_fields = data.keys() - field_names
        if extra_fields:
            logger.warning(
                "The following fields were present in the request "
                "but ignored: %s",
                extra_fields,
            )
        return result
88
89


90
class ErrorInfo(OpenAIBaseModel):
    # Details of a single error: message, type and integer code are
    # required; `param` is optional.
    message: str
    type: str
    param: Optional[str] = None
    code: int
Zhuohan Li's avatar
Zhuohan Li committed
95
96


97
98
99
100
class ErrorResponse(OpenAIBaseModel):
    # Envelope whose single required `error` field carries an ErrorInfo.
    error: ErrorInfo


101
class ModelPermission(OpenAIBaseModel):
    # Permission record listed on a ModelCard. Every field has a default;
    # `id` and `created` are generated per instance.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False
Zhuohan Li's avatar
Zhuohan Li committed
114
115


116
class ModelCard(OpenAIBaseModel):
    # Description of one model. Only `id` is required; `created` defaults
    # to the current time and `permission` to an empty list.
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: list[ModelPermission] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
125
126


127
class ModelList(OpenAIBaseModel):
    # List container of ModelCard entries; `data` defaults to empty.
    object: str = "list"
    data: list[ModelCard] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
130
131


132
133
134
135
class PromptTokenUsageInfo(OpenAIBaseModel):
    # Prompt-token detail record; its only declared field is the optional
    # `cached_tokens` count.
    cached_tokens: Optional[int] = None


136
class UsageInfo(OpenAIBaseModel):
    # Token counters for a request. All counters default to 0;
    # `prompt_tokens_details` is an optional PromptTokenUsageInfo.
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
    prompt_tokens_details: Optional[PromptTokenUsageInfo] = None
Zhuohan Li's avatar
Zhuohan Li committed
141
142


143
144
145
146
147
class RequestResponseMetadata(BaseModel):
    # Plain pydantic model (not an OpenAIBaseModel) pairing a request id
    # with an optional final UsageInfo.
    request_id: str
    final_usage_info: Optional[UsageInfo] = None


148
149
150
151
152
class JsonSchemaResponseFormat(OpenAIBaseModel):
    # Named JSON schema used by ResponseFormat.json_schema.
    name: str
    description: Optional[str] = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: Optional[dict[str, Any]] = Field(default=None, alias='schema')
    strict: Optional[bool] = None


157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
class StructuralTag(OpenAIBaseModel):
    # One structure entry of a StructuralTagResponseFormat: required
    # `begin` and `end` strings plus an optional schema dict.
    begin: str
    # schema is the field, but that causes conflicts with pydantic so
    # instead use structural_tag_schema with an alias
    structural_tag_schema: Optional[dict[str, Any]] = Field(default=None,
                                                            alias="schema")
    end: str


class StructuralTagResponseFormat(OpenAIBaseModel):
    # Response-format variant whose `type` must be "structural_tag"; it
    # carries a list of StructuralTag entries and a list of trigger strings.
    type: Literal["structural_tag"]
    structures: list[StructuralTag]
    triggers: list[str]


172
class ResponseFormat(OpenAIBaseModel):
    # Response-format variant selected by `type`; `json_schema` optionally
    # carries a JsonSchemaResponseFormat.

    # type must be "json_schema", "json_object", or "text"
    type: Literal["text", "json_object", "json_schema"]
    json_schema: Optional[JsonSchemaResponseFormat] = None
176
177


178
179
180
# Either response-format model; this is the type of
# ChatCompletionRequest.response_format.
AnyResponseFormat = Union[ResponseFormat, StructuralTagResponseFormat]


181
class StreamOptions(OpenAIBaseModel):
    # Streaming flags: `include_usage` defaults to True and
    # `continuous_usage_stats` to False.
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = False
184
185


186
187
188
class FunctionDefinition(OpenAIBaseModel):
    # Definition of a callable tool function: a required name, plus an
    # optional description and an optional `parameters` dict.
    name: str
    description: Optional[str] = None
    parameters: Optional[dict[str, Any]] = None
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205


class ChatCompletionToolsParam(OpenAIBaseModel):
    # One tool entry: `type` only admits "function" (the default) and
    # `function` holds its FunctionDefinition.
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    # Reference to a function by name only.
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    # Tool choice that names one specific function; `type` only admits
    # "function" (the default).
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


206
207
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
208
209
class LogitsProcessorConstructor(BaseModel):
    # Request-side description of how to build a logits processor: the
    # qualified name of a callable plus optional positional and keyword
    # arguments. Extra fields are rejected.
    qualname: str
    args: Optional[list[Any]] = None
    kwargs: Optional[dict[str, Any]] = None

    model_config = ConfigDict(extra="forbid")

215

216
# Each entry is either a qualified name (str) or a
# LogitsProcessorConstructor carrying a qualname with optional args/kwargs.
LogitsProcessors = list[Union[str, LogitsProcessorConstructor]]
217
218
219


def get_logits_processors(processors: Optional[LogitsProcessors],
                          pattern: Optional[str]) -> Optional[list[Any]]:
    """Resolve the requested logits processors against the allow-pattern.

    Args:
        processors: Entries given either as a qualified name or as a
            constructor object (qualname plus optional args/kwargs).
            ``None`` or an empty list means nothing was requested.
        pattern: Regular expression that every qualified name must match
            (checked with ``re.match``, i.e. anchored at the start). A
            falsy pattern means logits processors are not accepted.

    Returns:
        The resolved objects in request order, or ``None`` when no
        processors were requested. Constructor entries are replaced by the
        result of calling the resolved object with their args/kwargs.

    Raises:
        ValueError: If processors were requested but no pattern is set, if
            a qualified name does not match the pattern, or if a qualified
            name cannot be resolved.
    """
    if not processors:
        return None

    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information.")

    resolved = []
    for processor in processors:
        qualname = (processor
                    if isinstance(processor, str) else processor.qualname)

        # Check the pattern before resolving, so a disallowed name is
        # never looked up.
        if not re.match(pattern, qualname):
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information.")

        try:
            logits_processor = resolve_obj_by_qualname(qualname)
        except Exception as e:
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {e}"
            ) from e

        # A constructor entry names a class/factory: call it with the
        # supplied arguments to obtain the actual processor.
        if isinstance(processor, LogitsProcessorConstructor):
            logits_processor = logits_processor(*(processor.args or []),
                                                **(processor.kwargs or {}))
        resolved.append(logits_processor)
    return resolved


250
# Item types accepted in the list form of ResponsesRequest.input.
ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam,
                                           ResponseReasoningItem,
                                           ResponseFunctionToolCall]


255
256
257
258
259
260
261
262
263
264
265
266
267
268
class ResponsesRequest(OpenAIBaseModel):
    # Request body for the Responses API. to_sampling_params() converts it
    # into SamplingParams; the "before" validators reject unsupported
    # combinations on the raw input.

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/responses/create
    background: Optional[bool] = False
    include: Optional[list[
        Literal[
            "code_interpreter_call.outputs",
            "computer_call_output.output.image_url",
            "file_search_call.results",
            "message.input_image.image_url",
            "message.output_text.logprobs",
            "reasoning.encrypted_content",
        ],
    ]] = None
    input: Union[str, list[ResponseInputOutputItem]]
    instructions: Optional[str] = None
    max_output_tokens: Optional[int] = None
    max_tool_calls: Optional[int] = None
    metadata: Optional[Metadata] = None
    model: Optional[str] = None
    parallel_tool_calls: Optional[bool] = True
    previous_response_id: Optional[str] = None
    prompt: Optional[ResponsePrompt] = None
    reasoning: Optional[Reasoning] = None
    service_tier: Literal["auto", "default", "flex", "scale",
                          "priority"] = "auto"
    store: Optional[bool] = True
    stream: Optional[bool] = False
    temperature: Optional[float] = None
    text: Optional[ResponseTextConfig] = None
    tool_choice: ToolChoice = "auto"
    tools: list[Tool] = Field(default_factory=list)
    top_logprobs: Optional[int] = 0
    top_p: Optional[float] = None
    truncation: Optional[Literal["auto", "disabled"]] = "disabled"
    user: Optional[str] = None

    # --8<-- [start:responses-extra-params]
    request_id: str = Field(
        default_factory=lambda: f"resp_{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."),
    )
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )
    cache_salt: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."))
    # --8<-- [end:responses-extra-params]

    # Last-resort values used when neither the request nor the caller's
    # default_sampling_params supplies one.
    _DEFAULT_SAMPLING_PARAMS = {
        "temperature": 1.0,
        "top_p": 1.0,
    }

    def to_sampling_params(
        self,
        default_max_tokens: int,
        default_sampling_params: Optional[dict] = None,
    ) -> SamplingParams:
        """Build SamplingParams from this request.

        Args:
            default_max_tokens: Token budget used when `max_output_tokens`
                is unset; it also caps `max_output_tokens` when set.
            default_sampling_params: Optional fallback values for
                `temperature` and `top_p` (used when the request leaves
                them unset) and the source of `stop_token_ids`.

        Returns:
            The SamplingParams built via ``SamplingParams.from_optional``.

        Raises:
            NotImplementedError: If `text.format.type` is "json_object".
        """
        if self.max_output_tokens is None:
            max_tokens = default_max_tokens
        else:
            max_tokens = min(self.max_output_tokens, default_max_tokens)

        default_sampling_params = default_sampling_params or {}
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
        stop_token_ids = default_sampling_params.get("stop_token_ids")

        # Structured output
        guided_decoding = None
        if self.text is not None and self.text.format is not None:
            response_format = self.text.format
            if response_format.type == "json_schema":
                guided_decoding = GuidedDecodingParams.from_optional(
                    json=response_format.schema_)
            elif response_format.type == "json_object":
                raise NotImplementedError("json_object is not supported")

        # TODO: add more parameters
        return SamplingParams.from_optional(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            logprobs=self.top_logprobs
            if self.is_include_output_logprobs() else None,
            stop_token_ids=stop_token_ids,
            output_kind=(RequestOutputKind.DELTA
                         if self.stream else RequestOutputKind.FINAL_ONLY),
            guided_decoding=guided_decoding,
        )

    def is_include_output_logprobs(self) -> bool:
        """Check if the request includes output logprobs."""
        if self.include is None:
            return False
        return isinstance(
            self.include,
            list) and "message.output_text.logprobs" in self.include

    @model_validator(mode="before")
    @classmethod
    def validate_background(cls, data):
        """Reject `background` requests that disable `store`.

        Raises:
            ValueError: If `background` is truthy while `store` is falsy.
        """
        # "before" validators receive the raw input, which is not
        # necessarily a dict; anything else is left untouched.
        if not isinstance(data, dict):
            return data
        if not data.get("background"):
            return data
        if not data.get("store", True):
            raise ValueError(
                "background can only be used when `store` is true")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt(cls, data):
        """Reject requests that supply a `prompt` template.

        Raises:
            ValueError: If `prompt` is present and not None.
        """
        if not isinstance(data, dict):
            return data
        if data.get("prompt") is not None:
            raise ValueError("prompt template is not supported")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate `cache_salt` when one is supplied.

        Raises:
            ValueError: If `cache_salt` is given while
                ``envs.VLLM_USE_V1`` is falsy, or if it is not a non-empty
                string.
        """
        if not isinstance(data, dict):
            return data
        if data.get("cache_salt") is not None:
            if not envs.VLLM_USE_V1:
                raise ValueError(
                    "Parameter 'cache_salt' is not supported with "
                    "this instance of vLLM, which uses engine V0.")
            if not isinstance(data["cache_salt"],
                              str) or not data["cache_salt"]:
                raise ValueError("Parameter 'cache_salt' must be a "
                                 "non-empty string if provided.")
        return data

405

406
class ChatCompletionRequest(OpenAIBaseModel):
407
408
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: list[ChatCompletionMessageParam]
    model: Optional[str] = None
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = Field(
        default=None,
        deprecated=
        'max_tokens is deprecated in favor of the max_completion_tokens field')
    max_completion_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[AnyResponseFormat] = None
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: Optional[Union[str, list[str]]] = []
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    tools: Optional[list[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[
        Literal["none"],
        Literal["auto"],
        Literal["required"],
        ChatCompletionNamedToolChoiceParam,
    ]] = "none"
    reasoning_effort: Optional[Literal["low", "medium", "high"]] = None
    include_reasoning: bool = True

    # NOTE this will be ignored by vLLM -- the model determines the behavior
    parallel_tool_calls: Optional[bool] = False
    user: Optional[str] = None

    # --8<-- [start:chat-completion-sampling-params]
    best_of: Optional[int] = None
    use_beam_search: bool = False
    top_k: Optional[int] = None
    min_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
    length_penalty: float = 1.0
    stop_token_ids: Optional[list[int]] = []
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
    prompt_logprobs: Optional[int] = None
    allowed_token_ids: Optional[list[int]] = None
    bad_words: list[str] = Field(default_factory=list)
    # --8<-- [end:chat-completion-sampling-params]

    # --8<-- [start:chat-completion-extra-params]
    echo: bool = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: bool = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    continue_final_message: bool = Field(
        default=False,
        description=
        ("If this is set, the chat will be formatted so that the final "
         "message in the chat is open-ended, without any EOS tokens. The "
         "model will continue this message rather than starting a new one. "
         "This allows you to \"prefill\" part of the model's response for it. "
         "Cannot be used at the same time as `add_generation_prompt`."),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    documents: Optional[list[dict[str, str]]] = Field(
        default=None,
        description=
        ("A list of dicts representing documents that will be accessible to "
         "the model if it is performing RAG (retrieval-augmented generation)."
         " If the template does not support RAG, this argument will have no "
         "effect. We recommend that each document should be a dict containing "
         "\"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."),
    )
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[list[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    structural_tag: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the structural tag schema."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
            "'outlines' / 'lm-format-enforcer'"),
    )
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."),
    )
    logits_processors: Optional[LogitsProcessors] = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."))
    return_tokens_as_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."))
    return_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."))
    cache_salt: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."))
    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.")

    vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
        default=None,
        description=("Additional request parameters with string or "
                     "numeric values, used by custom extensions."),
    )

    # --8<-- [end:chat-completion-extra-params]

    # Default sampling parameters for chat completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_beam_search_params(
            self, max_tokens: int,
            default_sampling_params: dict) -> BeamSearchParams:
        """Build BeamSearchParams from this request.

        Args:
            max_tokens: Passed through as the generation token budget.
            default_sampling_params: Fallback source for `temperature`
                when the request does not set it.

        Returns:
            BeamSearchParams whose beam width is the request's `n`
            (1 when `n` is None).
        """
        beam_width = 1 if self.n is None else self.n

        # Precedence: request value, then caller-supplied default, then the
        # class-level default.
        temperature = self.temperature
        if temperature is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])

        return BeamSearchParams(
            beam_width=beam_width,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )
644

645
    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: Optional[str],
        default_sampling_params: dict,
    ) -> SamplingParams:
        """Build SamplingParams from this request.

        Args:
            max_tokens: Passed through as the generation token budget.
            logits_processor_pattern: Allow-pattern forwarded to
                ``get_logits_processors`` together with the request's
                `logits_processors`.
            default_sampling_params: Fallback values for
                `repetition_penalty`, `temperature`, `top_p`, `top_k` and
                `min_p` when the request leaves them unset.

        Returns:
            The SamplingParams built via ``SamplingParams.from_optional``.

        Note:
            When `response_format` is of type "json_schema" or
            "structural_tag", this method overwrites `self.guided_json` or
            `self.structural_tag` respectively.
        """
        class_defaults = self._DEFAULT_SAMPLING_PARAMS

        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                class_defaults["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", class_defaults["temperature"])
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get("top_p",
                                                class_defaults["top_p"])
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get("top_k",
                                                class_defaults["top_k"])
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get("min_p",
                                                class_defaults["min_p"])

        # With `echo` set and no explicit prompt_logprobs, fall back to
        # top_logprobs.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        guided_json_object = None
        if self.response_format is not None:
            if self.response_format.type == "json_object":
                guided_json_object = True
            elif self.response_format.type == "json_schema":
                json_schema = self.response_format.json_schema
                assert json_schema is not None
                self.guided_json = json_schema.json_schema
            elif self.response_format.type == "structural_tag":
                structural_tag = self.response_format
                assert structural_tag is not None and isinstance(
                    structural_tag, StructuralTagResponseFormat)
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structural_tag = json.dumps(s_tag_obj)

        guided_decoding = GuidedDecodingParams.from_optional(
            # A schema derived from the tool choice takes precedence over
            # guided_json.
            json=self._get_guided_json_from_tool() or self.guided_json,
            regex=self.guided_regex,
            choice=self.guided_choice,
            grammar=self.guided_grammar,
            json_object=guided_json_object,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern,
            structural_tag=self.structural_tag,
        )

        # Work on a copy: inserting kv_transfer_params below must not
        # mutate the request's own vllm_xargs dict.
        extra_args: dict[str, Any] = dict(self.vllm_xargs or {})
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params

        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            logits_processors=get_logits_processors(self.logits_processors,
                                                    logits_processor_pattern),
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=(RequestOutputKind.DELTA
                         if self.stream else RequestOutputKind.FINAL_ONLY),
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias,
            bad_words=self.bad_words,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753

    def _get_guided_json_from_tool(
            self) -> Optional[Union[str, dict, BaseModel]]:
        """Derive a guided-decoding JSON schema from the tool selection.

        This method does not modify ``self.tools``; tool parameter schemas
        are copied when ``$defs`` has to be hoisted, so repeated calls yield
        the same result.

        Returns:
            ``None`` when ``tool_choice`` is ``"none"``, when no tools were
            supplied, or when ``tool_choice`` is neither a named tool nor
            ``"required"``. The named tool's ``parameters`` when a specific
            tool is requested. For ``"required"``, an array schema with at
            least one item, where each item is a ``{"name", "parameters"}``
            object matching any one of the supplied tools.

        Raises:
            ValueError: If the named tool is not among ``tools``, or if two
                tools declare the same ``$defs`` entry with different
                schemas.
        """
        # user has chosen to not use any tool
        if self.tool_choice == "none" or self.tools is None:
            return None

        # user has chosen to use a named tool
        if isinstance(self.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_name = self.tool_choice.function.name
            tools = {tool.function.name: tool.function for tool in self.tools}
            if tool_name not in tools:
                raise ValueError(
                    f"Tool '{tool_name}' has not been passed in `tools`.")
            return tools[tool_name].parameters

        if self.tool_choice == "required":
            # Pydantic schema generation cannot be used since the JSON schema
            # has to be constructed for a specific instantiation of a tool list
            # so that parameters of a function are correctly generated
            # based on the chosen function name
            def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
                parameters = tool.function.parameters
                if parameters:
                    # `$defs` is hoisted to the top level of the combined
                    # schema; build a copy without it instead of popping it
                    # from the request's own tool definition.
                    parameters = {
                        key: value
                        for key, value in parameters.items()
                        if key != "$defs"
                    }
                else:
                    # parameters are always generated as '{}' in the final
                    # output if they are missing from the request
                    # (i.e. are None or '{}') so the schema is
                    # updated to produce an empty object in that case
                    parameters = {"type": "object", "properties": {}}
                return {
                    "properties": {
                        "name": {
                            "type": "string",
                            "enum": [tool.function.name]
                        },
                        "parameters": parameters
                    },
                    "required": ["name", "parameters"]
                }

            def get_tool_schema_defs(
                    tools: list[ChatCompletionToolsParam]) -> dict:
                # Merge every tool's `$defs` into one mapping, refusing
                # conflicting definitions that share a name.
                all_defs: dict[str, dict[str, Any]] = {}
                for tool in tools:
                    if tool.function.parameters is None:
                        continue
                    defs = tool.function.parameters.get("$defs", {})
                    for def_name, def_schema in defs.items():
                        if (def_name in all_defs
                                and all_defs[def_name] != def_schema):
                            raise ValueError(
                                f"Tool definition '{def_name}' has "
                                "multiple schemas, which is not "
                                "supported.")
                        all_defs[def_name] = def_schema
                return all_defs

            json_schema = {
                "type": "array",
                "minItems": 1,
                "items": {
                    "type": "object",
                    "anyOf": [get_tool_schema(tool) for tool in self.tools]
                }
            }
            json_schema_defs = get_tool_schema_defs(self.tools)
            if json_schema_defs:
                json_schema["$defs"] = json_schema_defs
            return json_schema

        return None
811

812
    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` when the request is not streaming.

        Raises:
            ValueError: If a truthy `stream_options` is given while `stream`
                is missing or falsy.
        """
        has_stream_options = bool(data.get("stream_options"))
        if has_stream_options and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `top_logprobs` on the raw input.

        Raises:
            ValueError: If `prompt_logprobs` is positive while streaming, if
                either value is negative, or if a positive `top_logprobs` is
                given without a truthy `logprobs`.
        """
        prompt_logprobs = data.get("prompt_logprobs")
        if prompt_logprobs is not None:
            streaming = data.get("stream")
            if streaming and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")

            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")

        top_logprobs = data.get("top_logprobs")
        if top_logprobs is None:
            return data

        if top_logprobs < 0:
            raise ValueError("`top_logprobs` must be a positive value.")

        if top_logprobs > 0 and not data.get("logprobs"):
            raise ValueError(
                "when using `top_logprobs`, `logprobs` must be set to true."
            )

        return data
842

843
844
845
    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Ensure guided decoding options do not conflict.

        Only `guided_json`, `guided_regex` and `guided_choice` are counted.

        Raises:
            ValueError: If more than one of the counted options is set, or if
                one of them is combined with a `tool_choice` other than
                "none", "auto" or "required" (i.e. a named tool). If `data`
                itself is a `ValueError`, it is re-raised.
        """
        if isinstance(data, ValueError):
            raise data

        guide_count = sum(
            data.get(key) is not None
            for key in ("guided_json", "guided_regex", "guided_choice"))

        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")

        # An explicit null `tool_choice` means the same as leaving it out.
        tool_choice = data.get("tool_choice")
        if tool_choice is None:
            tool_choice = "none"

        # you can only either use guided decoding or tools, not both
        # (this check used to require `guide_count > 1`, which the check
        # above makes impossible, so it never fired)
        if guide_count > 0 and tool_choice not in (
                "none",
                "auto",
                "required",
        ):
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        """Validate the raw `tool_choice` / `tools` combination.

        Side effects on `data`:
            * `tool_choice` defaults to "auto" when it is absent and a
              non-empty `tools` is given.
            * `tool_choice="required"` with an empty `tools` list is rewritten
              to `tool_choice="none"` and `tools` is removed.

        Raises:
            ValueError: If `tool_choice` is set without `tools`, has an
                unsupported value, is a malformed named-tool object, or names
                a tool that is not present in `tools`.
        """
        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        tool_choice = data.get("tool_choice")

        # if "tool_choice" is missing, null or "none" -- no validation is
        # needed for tools
        if tool_choice is None or tool_choice == "none":
            return data

        # ensure that if "tool choice" is specified, tools are present
        if data.get("tools") is None:
            raise ValueError("When using `tool_choice`, `tools` must be set.")

        # make sure that tool choice is either a named tool
        # OR that it's set to "auto" or "required"
        if tool_choice not in ("auto", "required") and not isinstance(
                tool_choice, dict):
            raise ValueError(
                f'Invalid value for `tool_choice`: {tool_choice}! '
                'Only named tools, "none", "auto" or "required" '
                'are supported.')

        # if tool_choice is "required" but the "tools" list is empty,
        # override the data to behave like "none" to align with
        # OpenAI's behavior.
        if tool_choice == "required" and isinstance(
                data["tools"], list) and len(data["tools"]) == 0:
            data["tool_choice"] = "none"
            del data["tools"]
            return data

        # ensure that if "tool_choice" is specified as an object,
        # it matches a valid tool
        if isinstance(tool_choice, dict):
            correct_usage_message = 'Correct usage: `{"type": "function",' \
                ' "function": {"name": "my_function"}}`'

            function = tool_choice.get("function")
            if not isinstance(function, dict):
                raise ValueError(
                    f"Invalid value for `function`: `{function}` in "
                    f"`tool_choice`! {correct_usage_message}")
            if "name" not in function:
                raise ValueError(f"Expected field `name` in `function` in "
                                 f"`tool_choice`! {correct_usage_message}")
            function_name = function["name"]
            if not isinstance(function_name, str) or len(function_name) == 0:
                raise ValueError(
                    f"Invalid `name` in `function`: `{function_name}`"
                    f" in `tool_choice`! {correct_usage_message}")

            # `tools` has not been validated yet (this runs in "before"
            # mode), so entries may be malformed. Ignore entries without a
            # dict `function` rather than failing with KeyError/TypeError,
            # which would bypass validation error reporting.
            valid_tool = any(
                isinstance(tool, dict)
                and isinstance(tool.get("function"), dict)
                and tool["function"].get("name") == function_name
                for tool in data["tools"])
            if not valid_tool:
                raise ValueError(
                    "The tool specified in `tool_choice` does not match any"
                    " of the specified `tools`")

        return data

940
941
942
943
944
945
946
947
948
    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Forbid enabling both prompt-continuation flags at once.

        Raises:
            ValueError: If `continue_final_message` and
                `add_generation_prompt` are both truthy.
        """
        continue_final = data.get("continue_final_message")
        if continue_final and data.get("add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data

949
950
951
952
953
954
955
956
957
958
959
960
961
962
    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate the optional `cache_salt` parameter.

        Raises:
            ValueError: If `cache_salt` is given while `envs.VLLM_USE_V1` is
                falsy, or if it is not a non-empty string.
        """
        salt = data.get("cache_salt")
        if salt is None:
            return data

        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Parameter 'cache_salt' is not supported with "
                "this instance of vLLM, which uses engine V0.")

        is_non_empty_str = isinstance(salt, str) and bool(salt)
        if not is_non_empty_str:
            raise ValueError("Parameter 'cache_salt' must be a "
                             "non-empty string if provided.")
        return data

Zhuohan Li's avatar
Zhuohan Li committed
963

964
class CompletionRequest(OpenAIBaseModel):
    # Request body model for text completions. It carries the
    # OpenAI-documented fields first, then additional sampling parameters,
    # then extra request parameters, and converts itself into
    # `BeamSearchParams` / `SamplingParams`.

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: Optional[str] = None
    prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None
    prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: Optional[Union[str, list[str]]] = []
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    user: Optional[str] = None

    # --8<-- [start:completion-sampling-params]
    use_beam_search: bool = False
    top_k: Optional[int] = None
    min_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
    length_penalty: float = 1.0
    stop_token_ids: Optional[list[int]] = []
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
    allowed_token_ids: Optional[list[int]] = None
    prompt_logprobs: Optional[int] = None
    # --8<-- [end:completion-sampling-params]

    # --8<-- [start:completion-extra-params]
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
    response_format: Optional[AnyResponseFormat] = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
            ", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
        ),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description="If specified, the output will follow the JSON schema.",
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[list[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"),
    )
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."),
    )
    logits_processors: Optional[LogitsProcessors] = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."))

    return_tokens_as_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."))
    return_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."))

    cache_salt: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."))

    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.")

    vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
        default=None,
        description=("Additional request parameters with string or "
                     "numeric values, used by custom extensions."),
    )

    # --8<-- [end:completion-extra-params]

    # Default sampling parameters for completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self,
        max_tokens: int,
        default_sampling_params: Optional[dict] = None,
    ) -> BeamSearchParams:
        """Convert this request into `BeamSearchParams`.

        Args:
            max_tokens: Token budget forwarded unchanged.
            default_sampling_params: Optional fallback values; only
                "temperature" is consulted, and only when the request does
                not set a temperature.

        Returns:
            Beam search parameters whose beam width is the request's `n`.
        """
        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get("temperature", 1.0)

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )

    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: Optional[str],
        default_sampling_params: Optional[dict] = None,
    ) -> SamplingParams:
        """Convert this request into `SamplingParams`.

        Args:
            max_tokens: Token budget for generation. It is replaced by 1 when
                the request asks for echo with `max_tokens == 0`.
            logits_processor_pattern: Passed through to
                `get_logits_processors` together with the request's
                `logits_processors`.
            default_sampling_params: Optional fallback values for
                `repetition_penalty`, `temperature`, `top_p`, `top_k` and
                `min_p`, used when the request leaves them unset. Missing
                keys fall back to `_DEFAULT_SAMPLING_PARAMS`.

        Returns:
            The sampling parameters for this request, including guided
            decoding settings derived from the `guided_*` fields and
            `response_format`.

        Raises:
            ValueError: If `response_format` has type "json_schema" but
                carries no `json_schema`.
        """
        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])

        # With echo, `logprobs` doubles as `prompt_logprobs` unless the
        # latter was set explicitly.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        echo_without_generation = self.echo and self.max_tokens == 0

        # Translate `response_format` into guided decoding settings. Every
        # type advertised in the field description is honoured here; results
        # are kept in locals so the request object is not modified.
        guided_json = self.guided_json
        guided_json_object = None
        structural_tag = None
        if self.response_format is not None:
            if self.response_format.type == "json_object":
                guided_json_object = True
            elif self.response_format.type == "json_schema":
                json_schema = self.response_format.json_schema
                if json_schema is None:
                    raise ValueError(
                        "`response_format` of type 'json_schema' requires "
                        "`json_schema` to be set.")
                guided_json = json_schema.json_schema
            elif self.response_format.type == "structural_tag":
                assert isinstance(self.response_format,
                                  StructuralTagResponseFormat)
                structural_tag = json.dumps(
                    self.response_format.model_dump(by_alias=True))

        guided_decoding = GuidedDecodingParams.from_optional(
            json=guided_json,
            regex=self.guided_regex,
            choice=self.guided_choice,
            grammar=self.guided_grammar,
            json_object=guided_json_object,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern,
            structural_tag=structural_tag,
        )

        # Copy so that adding `kv_transfer_params` below does not write into
        # the request's own `vllm_xargs` dict.
        extra_args: dict[str, Any] = (dict(self.vllm_xargs)
                                      if self.vllm_xargs else {})
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            logits_processors=get_logits_processors(self.logits_processors,
                                                    logits_processor_pattern),
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA if self.stream \
                else RequestOutputKind.FINAL_ONLY,
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Allow at most one of `guided_json`/`guided_regex`/`guided_choice`.

        Raises:
            ValueError: If more than one of those options is not None.
        """
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `logprobs` on the raw input.

        Raises:
            ValueError: If `prompt_logprobs` is positive while streaming, or
                if either value is negative.
        """
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")

            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")

        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise ValueError("`logprobs` must be a positive value.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` when the request is not streaming.

        Raises:
            ValueError: If a truthy `stream_options` is given while `stream`
                is missing or falsy.
        """
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt_and_prompt_embeds(cls, data):
        """Require at least one of `prompt` and `prompt_embeds`.

        Raises:
            ValueError: If both are missing or None.
        """
        if data.get("prompt") is None and data.get("prompt_embeds") is None:
            raise ValueError(
                "At least one of `prompt` or `prompt_embeds` must be set.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate the optional `cache_salt` parameter.

        Raises:
            ValueError: If `cache_salt` is given while `envs.VLLM_USE_V1` is
                falsy, or if it is not a non-empty string.
        """
        if data.get("cache_salt") is not None:
            if not envs.VLLM_USE_V1:
                raise ValueError(
                    "Parameter 'cache_salt' is not supported with "
                    "this instance of vLLM, which uses engine V0.")
            if not isinstance(data["cache_salt"],
                              str) or not data["cache_salt"]:
                raise ValueError("Parameter 'cache_salt' must be a "
                                 "non-empty string if provided.")
        return data

Zhuohan Li's avatar
Zhuohan Li committed
1292

1293
class EmbeddingCompletionRequest(OpenAIBaseModel):
    # Embedding request whose input is given directly as text or token ids.

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: Optional[str] = None
    input: Union[list[int], list[list[int]], str, list[str]]
    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:embedding-extra-params]
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."),
    )
    normalize: Optional[bool] = None

    # --8<-- [end:embedding-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from this request.

        Forwards `truncate_prompt_tokens`, `dimensions` and `normalize`
        unchanged.
        """
        pooling_kwargs = {
            "truncate_prompt_tokens": self.truncate_prompt_tokens,
            "dimensions": self.dimensions,
            "normalize": self.normalize,
        }
        return PoolingParams(**pooling_kwargs)
1333
1334


1335
class EmbeddingChatRequest(OpenAIBaseModel):
    # Embedding request whose input is given as a list of chat messages.
    model: Optional[str] = None
    messages: list[ChatCompletionMessageParam]

    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:chat-embedding-extra-params]
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."),
    )
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."),
    )
    normalize: Optional[bool] = None
    # --8<-- [end:chat-embedding-extra-params]

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Forbid enabling both prompt-continuation flags at once.

        Raises:
            ValueError: If `continue_final_message` and
                `add_generation_prompt` are both truthy in the raw input.
        """
        continue_final = data.get("continue_final_message")
        if continue_final and data.get("add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data

    def to_pooling_params(self):
        """Build `PoolingParams` from this request.

        Forwards `truncate_prompt_tokens`, `dimensions` and `normalize`
        unchanged.
        """
        pooling_kwargs = {
            "truncate_prompt_tokens": self.truncate_prompt_tokens,
            "dimensions": self.dimensions,
            "normalize": self.normalize,
        }
        return PoolingParams(**pooling_kwargs)
1403
1404
1405
1406


# An embedding request is either the completion-style or the chat-style model.
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]

1407
1408
# Pooling requests reuse the embedding request models under different names.
PoolingCompletionRequest = EmbeddingCompletionRequest
PoolingChatRequest = EmbeddingChatRequest
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426

# Type variable for the generic `data` payload of the IOProcessor models.
T = TypeVar("T")


class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
    """Request whose payload is interpreted by an IOProcessor plugin.

    Generic over `T`, the type of `data`.
    """
    model: Optional[str] = None

    priority: int = Field(default=0)
    """
    The priority of the request (lower means earlier handling;
    default: 0). Any priority other than 0 will raise an error
    if the served model does not use priority scheduling.
    """
    data: T
    """
    When using IOProcessor plugins, the actual input is processed
    by the plugin itself. Hence, we use a generic type for the request data
    """
    softmax: bool = True

    def to_pooling_params(self):
        """Build `PoolingParams` for the "encode" task, forwarding `softmax`."""
        return PoolingParams(task="encode", softmax=self.softmax)
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449


class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
    # Response envelope for plugin-generated output; generic over `T`, the
    # type of `data`.

    request_id: Optional[str] = None
    """
    The request_id associated with this response
    """
    # Defaults to the current Unix time, truncated to whole seconds.
    created_at: int = Field(default_factory=lambda: int(time.time()))

    data: T
    """
    When using plugins IOProcessor plugins, the actual output is generated
    by the plugin itself. Hence, we use a generic type for the response data
    """


# Any request model accepted for pooling: completion-style, chat-style, or
# plugin-processed.
PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest,
                       IOProcessorRequest]
1450

1451

1452
class ScoreRequest(OpenAIBaseModel):
    """Request body for scoring `text_1` against `text_2`."""
    model: Optional[str] = None
    # Each side may be a single string, a list of strings, or a multimodal
    # parameter object.
    text_1: Union[list[str], str, ScoreMultiModalParam]
    text_2: Union[list[str], str, ScoreMultiModalParam]
    # Validated to be >= -1 when provided.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:score-extra-params]

    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )

    activation: Optional[bool] = None

    # --8<-- [end:score-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from the truncation and activation settings."""
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation)
1481
1482


1483
class RerankRequest(OpenAIBaseModel):
    """Request body for reranking `documents` against a `query`."""
    model: Optional[str] = None
    query: Union[str, ScoreMultiModalParam]
    documents: Union[list[str], ScoreMultiModalParam]
    # A plain default replaces the constant-returning factory; the value is
    # still 0 when the caller omits the field.
    top_n: int = Field(default=0)
    # Validated to be >= -1 when provided.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:rerank-extra-params]

    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )

    activation: Optional[bool] = None

    # --8<-- [end:rerank-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from the truncation and activation settings."""
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation)
1513
1514
1515


class RerankDocument(BaseModel):
    """A reranked document, carried as text and/or a multimodal part."""
    text: Optional[str] = None
    multi_modal: Optional[ScoreContentPartParam] = None
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533


class RerankResult(BaseModel):
    # One entry of a rerank response: a document with its relevance score.
    # NOTE(review): `index` presumably refers to the document's position in
    # the request's `documents` — confirm against the producer.
    index: int
    document: RerankDocument
    relevance_score: float


class RerankUsage(BaseModel):
    # Token accounting reported with a rerank response.
    total_tokens: int


class RerankResponse(OpenAIBaseModel):
    """Response body of a rerank request."""
    id: str
    model: str
    usage: RerankUsage
    results: list[RerankResult]
1535
1536


1537
class CompletionLogProbs(OpenAIBaseModel):
    """Log-probability lists of a completion; every list defaults to empty."""
    text_offset: list[int] = Field(default_factory=list)
    token_logprobs: list[Optional[float]] = Field(default_factory=list)
    tokens: list[str] = Field(default_factory=list)
    top_logprobs: list[Optional[dict[str,
                                     float]]] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
1543
1544


1545
class CompletionResponseChoice(OpenAIBaseModel):
    """One choice of a non-streaming completion response."""
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    token_ids: Optional[list[int]] = None  # For response
    prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
    prompt_token_ids: Optional[list[int]] = None  # For prompt
Zhuohan Li's avatar
Zhuohan Li committed
1560
1561


1562
class CompletionResponse(OpenAIBaseModel):
    """Response body of a non-streaming completion request."""
    # Generated as "cmpl-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: Literal["text_completion"] = "text_completion"
    # Defaults to the current Unix time in whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseChoice]
    service_tier: Optional[Literal["auto", "default", "flex", "scale",
                                   "priority"]] = None
    system_fingerprint: Optional[str] = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None, description="KVTransfer parameters.")
Zhuohan Li's avatar
Zhuohan Li committed
1576
1577


1578
class CompletionResponseStreamChoice(OpenAIBaseModel):
    """One choice inside a streamed completion chunk."""
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    # not part of the OpenAI spec but for tracing the tokens
    # prompt tokens is put into choice to align with CompletionResponseChoice
    prompt_token_ids: Optional[list[int]] = None
    token_ids: Optional[list[int]] = None
Zhuohan Li's avatar
Zhuohan Li committed
1594
1595


1596
class CompletionStreamResponse(OpenAIBaseModel):
    """One chunk of a streamed completion response."""
    # Generated as "cmpl-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    # Defaults to the current Unix time in whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
1603
1604


1605
class EmbeddingResponseData(OpenAIBaseModel):
    """One embedding entry of an embedding response."""
    index: int
    object: str = "embedding"
    # Either a list of floats or a string.
    # NOTE(review): the string form is presumably an encoded vector
    # (e.g. base64) — confirm against the producer.
    embedding: Union[list[float], str]
1609
1610


1611
class EmbeddingResponse(OpenAIBaseModel):
    """Response body of an embedding request."""
    # Generated as "embd-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    # Defaults to the current Unix time in whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[EmbeddingResponseData]
    usage: UsageInfo


1620
1621
1622
class PoolingResponseData(OpenAIBaseModel):
    """One pooled-output entry of a pooling response."""
    index: int
    object: str = "pooling"
    # A nested list of floats, a flat list of floats, or a string.
    data: Union[list[list[float]], list[float], str]
1624
1625
1626
1627
1628
1629
1630


class PoolingResponse(OpenAIBaseModel):
    """Response body of a pooling request."""
    # Generated as "pool-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
    object: str = "list"
    # Defaults to the current Unix time in whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[PoolingResponseData]
    usage: UsageInfo


1635
1636
1637
class ScoreResponseData(OpenAIBaseModel):
    """One score entry of a score response."""
    index: int
    object: str = "score"
    score: float
1639
1640
1641
1642
1643
1644
1645


class ScoreResponse(OpenAIBaseModel):
    """Response body of a score request."""
    # NOTE(review): the default id uses the "embd-" prefix rather than a
    # score-specific one — confirm whether that is intentional.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    # Defaults to the current Unix time in whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ScoreResponseData]
    usage: UsageInfo


1650
1651
1652
1653
1654
1655
class ClassificationRequest(OpenAIBaseModel):
    """Request body for classifying one or more input strings."""
    model: Optional[str] = None
    input: Union[list[str], str]
    truncate_prompt_tokens: Optional[int] = None
    user: Optional[str] = None

    # --8<-- [start:classification-extra-params]
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )

    activation: Optional[bool] = None

    # --8<-- [end:classification-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from the truncation and activation settings."""
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation)
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690


class ClassificationData(OpenAIBaseModel):
    # One classification result entry.
    index: int
    # Required field, but its value may be None.
    label: Optional[str]
    # NOTE(review): presumably one probability per class, so that
    # len(probs) == num_classes — confirm against the producer.
    probs: list[float]
    num_classes: int


class ClassificationResponse(OpenAIBaseModel):
    # Response body of a classification request.
    # `id` is generated as "classify-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
    object: str = "list"
    # Defaults to the current Unix time in whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ClassificationData]
    usage: UsageInfo


1691
1692
1693
1694
1695
1696
class FunctionCall(OpenAIBaseModel):
    # A function invocation: the function name plus its arguments as a string.
    # NOTE(review): `arguments` is presumably JSON-encoded — confirm.
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    """A complete tool call; only the "function" type is allowed."""
    # Generated by `make_tool_call_id` when not supplied.
    id: str = Field(default_factory=make_tool_call_id)
    type: Literal["function"] = "function"
    function: FunctionCall


1702
1703
1704
1705
1706
1707
1708
class DeltaFunctionCall(BaseModel):
    # Partial function call in which both fields are optional.
    name: Optional[str] = None
    arguments: Optional[str] = None


# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
    """Partial tool call; every field except `index` is optional."""
    id: Optional[str] = None
    type: Optional[Literal["function"]] = None
    index: int
    function: Optional[DeltaFunctionCall] = None


class ExtractedToolCallInformation(BaseModel):
    """Tool calls extracted from model output, plus any remaining content."""
    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: list[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: Optional[str] = None


1727
class ChatMessage(OpenAIBaseModel):
    """A complete chat message as returned in a chat completion choice."""
    role: str
    content: Optional[str] = None
    refusal: Optional[str] = None
    annotations: Optional[OpenAIAnnotation] = None
    audio: Optional[OpenAIChatCompletionAudio] = None
    function_call: Optional[FunctionCall] = None
    # Defaults to an empty list rather than None.
    tool_calls: list[ToolCall] = Field(default_factory=list)

    # vLLM-specific fields that are not in OpenAI spec
    reasoning_content: Optional[str] = None

1739

1740
1741
1742
class ChatCompletionLogProb(OpenAIBaseModel):
    """Log probability of a single token."""
    token: str
    # Defaults to -9999.0 when no value is supplied.
    logprob: float = -9999.0
    bytes: Optional[list[int]] = None
1744
1745
1746


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    """Token log probability together with a list of alternatives."""
    # Workaround: redefine fields name cache so that it's not
    # shared with the super class.
    field_names: ClassVar[Optional[set[str]]] = None
    top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)
1751
1752
1753


class ChatCompletionLogProbs(OpenAIBaseModel):
    """Container for the optional per-token log probability entries."""
    content: Optional[list[ChatCompletionLogProbsContent]] = None
1755
1756


1757
class ChatCompletionResponseChoice(OpenAIBaseModel):
    """One choice of a non-streaming chat completion response."""
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # per OpenAI spec this is the default
    finish_reason: Optional[str] = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: Optional[Union[int, str]] = None
    # not part of the OpenAI spec but is useful for tracing the tokens
    # in agent scenarios
    token_ids: Optional[list[int]] = None
1768
1769


1770
class ChatCompletionResponse(OpenAIBaseModel):
    """Response body of a non-streaming chat completion request."""
    # Generated as "chatcmpl-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    # Defaults to the current Unix time in whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseChoice]
    service_tier: Optional[Literal["auto", "default", "flex", "scale",
                                   "priority"]] = None
    system_fingerprint: Optional[str] = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
    prompt_token_ids: Optional[list[int]] = None
    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None, description="KVTransfer parameters.")
1786
1787


1788
class DeltaMessage(OpenAIBaseModel):
    """Incremental message fragment carried by a streamed chunk."""
    role: Optional[str] = None
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    # Defaults to an empty list rather than None.
    tool_calls: list[DeltaToolCall] = Field(default_factory=list)
1793
1794


1795
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    """One choice inside a streamed chat completion chunk."""
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None
    # not part of the OpenAI spec but for tracing the tokens
    token_ids: Optional[list[int]] = None
1803
1804


1805
class ChatCompletionStreamResponse(OpenAIBaseModel):
    """One chunk of a streamed chat completion response."""
    # Generated as "chatcmpl-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # Defaults to the current Unix time in whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
    # not part of the OpenAI spec but for tracing the tokens
    prompt_token_ids: Optional[list[int]] = None
1814
1815


1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed transcription chunk.
    delta: DeltaMessage
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class TranscriptionStreamResponse(OpenAIBaseModel):
    # One chunk of a streamed transcription response.
    # `id` is generated as "trsc-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
    object: Literal["transcription.chunk"] = "transcription.chunk"
    # Defaults to the current Unix time in whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranscriptionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
class InputTokensDetails(OpenAIBaseModel):
    # Breakdown of input tokens reported in `ResponseUsage`.
    cached_tokens: int


class OutputTokensDetails(OpenAIBaseModel):
    # Breakdown of output tokens reported in `ResponseUsage`.
    reasoning_tokens: int


class ResponseUsage(OpenAIBaseModel):
    # Token usage attached to a `ResponsesResponse`; all fields are required.
    input_tokens: int
    input_tokens_details: InputTokensDetails
    output_tokens: int
    output_tokens_details: OutputTokensDetails
    total_tokens: int
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855


class ResponsesResponse(OpenAIBaseModel):
    """Response object returned for a `ResponsesRequest`."""
    # Generated as "resp_<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
    # Defaults to the current Unix time in whole seconds.
    created_at: int = Field(default_factory=lambda: int(time.time()))
    # error: Optional[ResponseError] = None
    # incomplete_details: Optional[IncompleteDetails] = None
    instructions: Optional[str] = None
    metadata: Optional[Metadata] = None
    model: str
    object: Literal["response"] = "response"
    output: list[ResponseOutputItem]
    parallel_tool_calls: bool
    temperature: float
    tool_choice: ToolChoice
    tools: list[Tool]
    top_p: float
    background: bool
    max_output_tokens: int
    max_tool_calls: Optional[int] = None
    previous_response_id: Optional[str] = None
    prompt: Optional[ResponsePrompt] = None
    reasoning: Optional[Reasoning] = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"]
    status: ResponseStatus
    text: Optional[ResponseTextConfig] = None
    top_logprobs: Optional[int] = None
    truncation: Literal["auto", "disabled"]
    usage: Optional[ResponseUsage] = None
    user: Optional[str] = None

    @classmethod
    def from_request(
        cls,
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        model_name: str,
        created_time: int,
        output: list[ResponseOutputItem],
        status: ResponseStatus,
        usage: Optional[ResponseUsage] = None,
    ) -> "ResponsesResponse":
        """Build a response that echoes the settings of `request`.

        Args:
            request: Originating request; supplies the response `id` (its
                `request_id`) and the echoed request-level fields.
            sampling_params: Source of `temperature`, `top_p`,
                `max_output_tokens` (its `max_tokens`) and `top_logprobs`
                (its `logprobs`).
            model_name: Value stored in `model`.
            created_time: Value stored in `created_at`.
            output: Output items of the response.
            status: Status of the response.
            usage: Optional token usage.

        Returns:
            The populated `ResponsesResponse`.
        """
        return cls(
            id=request.request_id,
            created_at=created_time,
            instructions=request.instructions,
            metadata=request.metadata,
            model=model_name,
            output=output,
            parallel_tool_calls=request.parallel_tool_calls,
            temperature=sampling_params.temperature,
            tool_choice=request.tool_choice,
            tools=request.tools,
            top_p=sampling_params.top_p,
            background=request.background,
            max_output_tokens=sampling_params.max_tokens,
            max_tool_calls=request.max_tool_calls,
            previous_response_id=request.previous_response_id,
            prompt=request.prompt,
            reasoning=request.reasoning,
            service_tier=request.service_tier,
            status=status,
            text=request.text,
            top_logprobs=sampling_params.logprobs,
            truncation=request.truncation,
            user=request.user,
            usage=usage,
        )


1915
1916
1917
1918
# Every request model that may appear as the `body` of a batch input line.
BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest,
                              ScoreRequest, RerankRequest]


1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: The `body` is validated against the request model selected by
    `url`: `/v1/chat/completions`, `/v1/embeddings`, and URLs ending in
    `/score` or `/rerank`.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. It selects the
    # request model that `body` is validated against.
    url: str

    # The parameters of the request.
    body: BatchRequestInputBody

    @field_validator('body', mode='plain')
    @classmethod
    def check_type_for_url(cls, value: Any, info: ValidationInfo):
        """Validate `body` with the request model that matches `url`.

        Args:
            value: Raw, unvalidated body.
            info: Validation context; `info.data` holds the fields that
                were validated before `body`.

        Returns:
            The validated request model instance.
        """
        # Use url to disambiguate models.
        # `info.data` only contains fields that validated successfully, so
        # `url` is absent when it was missing or invalid. Fall back to ""
        # (matches no branch below) instead of raising KeyError, so the
        # caller gets a regular validation error for `url`.
        url: str = info.data.get("url", "")
        if url == "/v1/chat/completions":
            return ChatCompletionRequest.model_validate(value)
        if url == "/v1/embeddings":
            return TypeAdapter(EmbeddingRequest).validate_python(value)
        if url.endswith("/score"):
            return ScoreRequest.model_validate(value)
        if url.endswith("/rerank"):
            return RerankRequest.model_validate(value)
        # Unknown URL: accept any member of the body union.
        return TypeAdapter(BatchRequestInputBody).validate_python(value)
1955

1956

1957
1958
1959
1960
1961
1962
1963
1964
class BatchResponseData(OpenAIBaseModel):
    """Response payload for a single request of a batch."""
    # HTTP status code of the response.
    status_code: int = 200

    # An unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse,
                         ScoreResponse, RerankResponse]] = None
1967
1968


1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    # Required field, but its value may be None.
    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]
1985
1986


1987
class TokenizeCompletionRequest(OpenAIBaseModel):
    """Request body for tokenizing a plain text prompt."""
    model: Optional[str] = None
    prompt: str

    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
    return_token_strs: Optional[bool] = Field(
        default=False,
        description=("If true, also return the token strings "
                     "corresponding to the token ids."),
    )
2002
2003
2004


class TokenizeChatRequest(OpenAIBaseModel):
    """Request body for tokenizing a list of chat messages."""
    model: Optional[str] = None
    messages: list[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    return_token_strs: Optional[bool] = Field(
        default=False,
        description=("If true, also return the token strings "
                     "corresponding to the token ids."),
    )
    continue_final_message: bool = Field(
        default=False,
        description=
        ("If this is set, the chat will be formatted so that the final "
         "message in the chat is open-ended, without any EOS tokens. The "
         "model will continue this message rather than starting a new one. "
         "This allows you to \"prefill\" part of the model's response for it. "
         "Cannot be used at the same time as `add_generation_prompt`."),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."),
    )
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    tools: Optional[list[ChatCompletionToolsParam]] = Field(
        default=None,
        description=("A list of tools the model may call."),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject payloads that enable two mutually exclusive prompt flags.

        Runs before field validation, so `data` is the raw input and only
        explicitly supplied values are inspected.

        Raises:
            ValueError: If both `continue_final_message` and
                `add_generation_prompt` are truthy in `data`.
        """
        # Non-mapping input (e.g. a JSON array) is passed through so that
        # regular validation reports it instead of an AttributeError here.
        if not hasattr(data, "get"):
            return data
        if data.get("continue_final_message") and data.get(
                "add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data

2070
2071

# Either of the two tokenize request models.
TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
2072
2073
2074
2075
2076


class TokenizeResponse(OpenAIBaseModel):
    """Response body of a tokenize request."""
    count: int
    max_model_len: int
    tokens: list[int]
    # NOTE(review): presumably only set when the request asked for
    # `return_token_strs` — confirm against the handler.
    token_strs: Optional[list[str]] = None
2079
2080
2081


class DetokenizeRequest(OpenAIBaseModel):
    """Request body for converting token ids back into text."""
    model: Optional[str] = None
    tokens: list[int]
2084
2085
2086
2087


class DetokenizeResponse(OpenAIBaseModel):
    # Response body of a detokenize request: the resulting text.
    prompt: str
2088
2089


2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
class TokenizerInfoResponse(OpenAIBaseModel):
    """
    Response containing tokenizer configuration
    equivalent to tokenizer_config.json
    """

    # Extra keys are accepted and kept in addition to `tokenizer_class`.
    model_config = ConfigDict(extra="allow")
    tokenizer_class: str


2100
class LoadLoRAAdapterRequest(BaseModel):
    """Request body for loading a LoRA adapter by name and path."""
    lora_name: str
    lora_path: str


2105
class UnloadLoRAAdapterRequest(BaseModel):
    """Request body for unloading a LoRA adapter."""
    lora_name: str
    lora_int_id: Optional[int] = Field(default=None)
2108
2109
2110
2111
2112
2113
2114
2115
2116


## Protocols for Audio
# Output formats accepted in the `response_format` field of audio requests.
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json",
                                         "vtt"]


class TranscriptionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranscription

    file: UploadFile
    """
    The audio file object (not file name) to transcribe, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: Optional[str] = None
    """ID of the model to use.
    """

    language: Optional[str] = None
    """The language of the input audio.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy and latency.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    ## TODO (varun) : Support if set to 0, certain thresholds are met !!

    timestamp_granularities: list[Literal["word", "segment"]] = Field(
        alias="timestamp_granularities[]", default=[])
    """The timestamp granularities to populate for this transcription.

    `response_format` must be set `verbose_json` to use timestamp granularities.
    Either or both of these options are supported: `word`, or `segment`. Note:
    There is no additional latency for segment timestamps, but generating word
    timestamps incurs additional latency.
    """

    stream: Optional[bool] = False
    """When set, it will enable output to be streamed in a similar fashion
    as the Chat Completion endpoint.
    """
    # --8<-- [start:transcription-extra-params]
    # Flattened stream option to simplify form data.
    stream_include_usage: Optional[bool] = False
    stream_continuous_usage_stats: Optional[bool] = False

    vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
        default=None,
        description=("Additional request parameters with string or "
                     "numeric values, used by custom extensions."),
    )
    # --8<-- [end:transcription-extra-params]

    to_language: Optional[str] = None
    """The language of the output audio we transcribe to.

    Please note that this is not currently used by supported models at this
    time, but it is a placeholder for future use, matching translation api.
    """

    # --8<-- [start:transcription-sampling-params]
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    top_p: Optional[float] = None
    """Enables nucleus (top-p) sampling, where tokens are selected from the
    smallest possible set whose cumulative probability exceeds `p`.
    """

    top_k: Optional[int] = None
    """Limits sampling to the `k` most probable tokens at each step."""

    min_p: Optional[float] = None
    """Filters out tokens with a probability lower than `min_p`, ensuring a
    minimum likelihood threshold during sampling.
    """

    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    frequency_penalty: Optional[float] = 0.0
    """The frequency penalty to use for sampling."""

    repetition_penalty: Optional[float] = None
    """The repetition penalty to use for sampling."""

    presence_penalty: Optional[float] = 0.0
    """The presence penalty to use for sampling."""
    # --8<-- [end:transcription-sampling-params]

    # Default sampling parameters for transcription requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_sampling_params(
            self,
            default_max_tokens: int,
            default_sampling_params: Optional[dict] = None) -> SamplingParams:
        """Convert this request into `SamplingParams`.

        Each sampling field left as None is resolved first from
        `default_sampling_params` and then from `_DEFAULT_SAMPLING_PARAMS`.

        Args:
            default_max_tokens: Used as `max_tokens` as-is.
            default_sampling_params: Optional overrides for the class-level
                defaults; treated as empty when None.

        Returns:
            Sampling parameters whose output kind is DELTA when `stream` is
            truthy and FINAL_ONLY otherwise.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        # NOTE(review): `temperature` is declared as a non-optional float
        # defaulting to 0.0, so this fallback looks unreachable — confirm.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])

        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"])

        output_kind = (RequestOutputKind.DELTA
                       if self.stream else RequestOutputKind.FINAL_ONLY)

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            presence_penalty=self.presence_penalty,
            output_kind=output_kind,
            extra_args=self.vllm_xargs)

    @model_validator(mode="before")
    @classmethod
    def validate_transcription_request(cls, data):
        """Check the raw payload before field validation.

        Raises:
            HTTPException: With status 422 if `file` is a string instead of
                a file-like object.
            ValueError: If a stream option is set while `stream` is falsy.
        """
        # Non-mapping input is passed through so that regular validation
        # reports it instead of an AttributeError here.
        if not hasattr(data, "get"):
            return data

        if isinstance(data.get("file"), str):
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail="Expected 'file' to be a file-like object, not 'str'.",
            )

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data
2290
2291
2292


# Transcription response objects
2293
2294
2295
2296
2297
class TranscriptionUsageAudio(OpenAIBaseModel):
    """Usage record carried by `TranscriptionResponse.usage`."""
    # Discriminator tag; "duration" is the only accepted value and the default.
    type: Literal["duration"] = "duration"
    # Whole number of seconds (the field is an int). Presumably the length of
    # the input audio - confirm against the code that builds this object.
    seconds: int


2298
2299
2300
class TranscriptionResponse(OpenAIBaseModel):
    """Transcription result: the transcribed text plus a usage record."""

    text: str
    """The transcribed text."""

    usage: TranscriptionUsageAudio
    """Usage record for the request, expressed as a duration in seconds."""


class TranscriptionWord(OpenAIBaseModel):
    """A single word with its start and end timestamps.

    Used as the element type of `TranscriptionResponseVerbose.words`.
    """
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranscriptionSegment(OpenAIBaseModel):
    """Details of one segment of a transcription.

    Used as the element type of `TranscriptionResponseVerbose.segments`.
    All fields are required.
    """

    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranscriptionResponseVerbose(OpenAIBaseModel):
    """Transcription result with metadata and optional timing details.

    `duration`, `language` and `text` are required; `segments` and `words`
    default to ``None``.
    """

    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The transcribed text."""

    segments: Optional[list[TranscriptionSegment]] = None
    """Segments of the transcribed text and their corresponding details."""

    words: Optional[list[TranscriptionWord]] = None
    """Extracted words and their corresponding timestamps."""


class TranslationResponseStreamChoice(OpenAIBaseModel):
    """One entry of `TranslationStreamResponse.choices`."""
    # Payload of this choice, as a `DeltaMessage`; required.
    delta: DeltaMessage
    # Optional, defaults to None. Presumably set once generation for this
    # choice has ended - confirm against the code that emits the chunks.
    finish_reason: Optional[str] = None
    # Optional int or string, defaults to None. NOTE(review): its meaning is
    # not visible in this section; presumably the stop token id or stop string
    # that ended generation - confirm.
    stop_reason: Optional[Union[int, str]] = None


class TranslationStreamResponse(OpenAIBaseModel):
    """A single chunk of a streamed translation response."""
    # Defaults to a new identifier of the form "trsl-<random_uuid()>".
    id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
    # Fixed object tag; "translation.chunk" is the only accepted value.
    object: Literal["translation.chunk"] = "translation.chunk"
    # Defaults to the current time in whole seconds since the epoch, evaluated
    # when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    # Required model identifier string.
    model: str
    # Required list of per-choice payloads.
    choices: list[TranslationResponseStreamChoice]
    # Optional usage information; None unless supplied.
    usage: Optional[UsageInfo] = Field(default=None)


class TranslationRequest(OpenAIBaseModel):
    """Request parameters for an audio translation.

    The leading fields follow the OpenAI ``createTranslation`` reference.
    The fields inside the ``translation-extra-params`` section are additions
    on top of it.
    """

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranslation

    file: UploadFile
    """
    The audio file object (not file name) to translate, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: Optional[str] = None
    """ID of the model to use.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    # TODO support additional sampling parameters
    # --8<-- [start:translation-sampling-params]
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """
    # --8<-- [end:translation-sampling-params]

    # --8<-- [start:translation-extra-params]
    language: Optional[str] = None
    """The language of the input audio we translate from.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy.
    """

    to_language: Optional[str] = None
    """The language we translate the input audio into.

    Please note that this is not supported by all models, refer to the specific
    model documentation for more details.
    For instance, Whisper only supports `to_language=en`.
    """

    stream: Optional[bool] = False
    """Custom field not present in the original OpenAI definition. When set,
    it will enable output to be streamed in a similar fashion as the Chat
    Completion endpoint.
    """
    # Flattened stream option to simplify form data.
    stream_include_usage: Optional[bool] = False
    stream_continuous_usage_stats: Optional[bool] = False
    # --8<-- [end:translation-extra-params]

    # Default sampling parameters for translation requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
    }

    def to_sampling_params(
            self,
            default_max_tokens: int,
            default_sampling_params: Optional[dict] = None) -> SamplingParams:
        """Build the `SamplingParams` for this request.

        Args:
            default_max_tokens: Used directly as ``max_tokens``; the request
                has no field that overrides it.
            default_sampling_params: Optional fallback values. Only the
                ``"temperature"`` key is read, and only when the request's
                own temperature is ``None``.

        Returns:
            The result of `SamplingParams.from_optional`, called with the
            resolved temperature, ``max_tokens``, the request seed, and an
            output kind of ``DELTA`` when ``stream`` is truthy or
            ``FINAL_ONLY`` otherwise.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        # Default parameters
        # NOTE(review): `temperature` is declared as a plain float defaulting
        # to 0.0, so a validated request never holds None and this fallback
        # is not reached for it - confirm whether the field was meant to be
        # Optional.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])

        output_kind = (RequestOutputKind.DELTA
                       if self.stream else RequestOutputKind.FINAL_ONLY)
        return SamplingParams.from_optional(temperature=temperature,
                                            max_tokens=max_tokens,
                                            seed=self.seed,
                                            output_kind=output_kind)

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject stream options that are set without ``stream``.

        Args:
            data: The raw, not yet validated request payload. It is queried
                with ``.get``, so a mapping of field names to values is
                expected.

        Returns:
            ``data`` unchanged when the check passes.

        Raises:
            ValueError: When ``stream_include_usage`` or
                ``stream_continuous_usage_stats`` is truthy while ``stream``
                is falsy.
        """
        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data


# Translation response objects
class TranslationResponse(OpenAIBaseModel):
    """Translation result consisting of the translated text only."""
    text: str
    """The translated text."""


class TranslationWord(OpenAIBaseModel):
    """A single word with its start and end timestamps.

    Used as the element type of `TranslationResponseVerbose.words`.
    """
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranslationSegment(OpenAIBaseModel):
    """Details of one segment of a translation.

    Used as the element type of `TranslationResponseVerbose.segments`.
    All fields are required.
    """
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranslationResponseVerbose(OpenAIBaseModel):
    """Translation result with metadata and optional timing details.

    `duration`, `language` and `text` are required; `segments` and `words`
    default to ``None``.
    """
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The translated text."""

    segments: Optional[list[TranslationSegment]] = None
    """Segments of the translated text and their corresponding details."""

    words: Optional[list[TranslationWord]] = None
    """Extracted words and their corresponding timestamps."""