protocol.py 102 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
6
import json
Zhuohan Li's avatar
Zhuohan Li committed
7
import time
8
from http import HTTPStatus
9
from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar, Union
Zhuohan Li's avatar
Zhuohan Li committed
10

11
import regex as re
12
import torch
13
from fastapi import HTTPException, UploadFile
14
from openai.types.chat.chat_completion_audio import (
15
16
17
    ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
18
19
20
21
22
from openai.types.responses import (
    ResponseCodeInterpreterCallCodeDeltaEvent,
    ResponseCodeInterpreterCallCodeDoneEvent,
    ResponseCodeInterpreterCallCompletedEvent,
    ResponseCodeInterpreterCallInProgressEvent,
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
    ResponseCodeInterpreterCallInterpretingEvent,
    ResponseContentPartAddedEvent,
    ResponseContentPartDoneEvent,
    ResponseFunctionToolCall,
    ResponseInputItemParam,
    ResponseOutputItem,
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
    ResponsePrompt,
    ResponseReasoningItem,
    ResponseReasoningTextDeltaEvent,
    ResponseReasoningTextDoneEvent,
    ResponseStatus,
    ResponseWebSearchCallCompletedEvent,
    ResponseWebSearchCallInProgressEvent,
    ResponseWebSearchCallSearchingEvent,
)
40
from openai.types.responses import (
41
42
43
    ResponseCompletedEvent as OpenAIResponseCompletedEvent,
)
from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent
44
from openai.types.responses import (
45
46
    ResponseInProgressEvent as OpenAIResponseInProgressEvent,
)
47
from openai.types.responses.response_reasoning_item import (
48
49
    Content as ResponseReasoningTextContent,
)
50
51
52
53
54

# Backward compatibility for OpenAI client versions
try:  # For older openai versions (< 1.100.0)
    from openai.types.responses import ResponseTextConfig
except ImportError:  # For newer openai versions (>= 1.100.0)
55
    from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig
56

57
from openai.types.responses.response import IncompleteDetails, ToolChoice
58
59
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
60
61
62
63
64
65
66
67
68
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    TypeAdapter,
    ValidationInfo,
    field_validator,
    model_validator,
)
69
from typing_extensions import TypeAlias
Zhuohan Li's avatar
Zhuohan Li committed
70

71
from vllm import envs
72
73
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id
from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam
74
from vllm.logger import init_logger
75
from vllm.logprobs import Logprob
76
from vllm.pooling_params import PoolingParams
77
78
79
80
81
82
from vllm.sampling_params import (
    BeamSearchParams,
    RequestOutputKind,
    SamplingParams,
    StructuredOutputsParams,
)
83
from vllm.utils import random_uuid, resolve_obj_by_qualname
84

85
86
# Module-level logger obtained from vLLM's logging helper.
logger = init_logger(__name__)

# Integer limits of torch.long (exposes .min / .max); used further down to
# bound the `seed` request field.
_LONG_INFO = torch.iinfo(torch.long)
88

Zhuohan Li's avatar
Zhuohan Li committed
89

90
class OpenAIBaseModel(BaseModel):
    # Common base for the protocol models in this file: accepts unknown
    # fields and logs a warning naming the unknown top-level keys.

    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Cache class field names. Filled lazily, per class, by
    # __log_extra_fields__ the first time a dict payload is validated.
    field_names: ClassVar[Optional[set[str]]] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Validate `data` as usual, then warn about keys that are not fields.

        Args:
            data: The raw input handed to the model. Only dict inputs are
                inspected for unknown keys.
            handler: The wrapped validation callable supplied by pydantic.

        Returns:
            Whatever `handler(data)` returns, unchanged.
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Read the cache from this class's own namespace rather than through
        # normal attribute lookup: an inherited lookup would hand a subclass
        # the set cached on its parent, and the subclass's own fields would
        # then be reported as ignored.
        field_names = cls.__dict__.get("field_names")
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        unknown_keys = data.keys() - field_names
        if unknown_keys:
            logger.warning(
                "The following fields were present in the request but ignored: %s",
                unknown_keys,
            )
        return result
120
121


122
class ErrorInfo(OpenAIBaseModel):
    # Error details; this is the type of ErrorResponse.error.
    # `message`, `type` and `code` are required, `param` is optional.
    message: str
    type: str
    param: Optional[str] = None
    code: int
Zhuohan Li's avatar
Zhuohan Li committed
127
128


129
130
131
132
class ErrorResponse(OpenAIBaseModel):
    # Envelope carrying a single ErrorInfo under the `error` key.
    error: ErrorInfo


133
class ModelPermission(OpenAIBaseModel):
    # Permission entry; ModelCard.permission holds a list of these.
    # `id` and `created` are generated per instance: "modelperm-" followed by
    # the value of random_uuid(), and the current Unix time truncated to
    # whole seconds.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False
Zhuohan Li's avatar
Zhuohan Li committed
146
147


148
class ModelCard(OpenAIBaseModel):
    # Description of one model; ModelList.data holds a list of these.
    # Only `id` is required.
    id: str
    object: str = "model"
    # Defaults to the current Unix time truncated to whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    # default_factory gives every instance its own empty list.
    permission: list[ModelPermission] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
157
158


159
class ModelList(OpenAIBaseModel):
    # List envelope for ModelCard entries; `data` defaults to a fresh empty
    # list for each instance.
    object: str = "list"
    data: list[ModelCard] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
162
163


164
165
166
167
class PromptTokenUsageInfo(OpenAIBaseModel):
    # Detail record used as the type of UsageInfo.prompt_tokens_details.
    # NOTE(review): presumably the number of prompt tokens served from a
    # cache -- confirm against the code that populates it.
    cached_tokens: Optional[int] = None


168
class UsageInfo(OpenAIBaseModel):
    # Token accounting attached to responses. All counters default to 0;
    # `completion_tokens` may additionally be None.
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
    prompt_tokens_details: Optional[PromptTokenUsageInfo] = None
Zhuohan Li's avatar
Zhuohan Li committed
173
174


175
176
177
178
179
class RequestResponseMetadata(BaseModel):
    # Derives from plain BaseModel rather than OpenAIBaseModel, so the
    # extra-field allowance and warning defined there do not apply here.
    request_id: str
    # Unset (None) until a UsageInfo is assigned.
    final_usage_info: Optional[UsageInfo] = None


180
181
182
183
184
class JsonSchemaResponseFormat(OpenAIBaseModel):
    # Type of ResponseFormat.json_schema; only `name` is required.
    name: str
    description: Optional[str] = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: Optional[dict[str, Any]] = Field(default=None, alias="schema")
    strict: Optional[bool] = None


189
190
191
192
class StructuralTag(OpenAIBaseModel):
    # One entry of StructuralTagResponseFormat.structures: required `begin`
    # and `end` strings plus an optional schema dict.
    begin: str
    # schema is the field, but that causes conflicts with pydantic so
    # instead use structural_tag_schema with an alias
    structural_tag_schema: Optional[dict[str, Any]] = Field(
        default=None, alias="schema"
    )
    end: str


class StructuralTagResponseFormat(OpenAIBaseModel):
    # Response-format variant selected by type == "structural_tag"; it is one
    # arm of the AnyResponseFormat union. All three fields are required.
    type: Literal["structural_tag"]
    structures: list[StructuralTag]
    triggers: list[str]


205
class ResponseFormat(OpenAIBaseModel):
    # type must be "json_schema", "json_object", or "text"
    type: Literal["text", "json_object", "json_schema"]
    # Optional at the model level: nothing in this class requires it to be
    # set when type == "json_schema".
    json_schema: Optional[JsonSchemaResponseFormat] = None
209
210


211
212
213
# Either of the two response-format models defined above.
AnyResponseFormat = Union[ResponseFormat, StructuralTagResponseFormat]


214
class StreamOptions(OpenAIBaseModel):
    # Streaming options: `include_usage` defaults to True and
    # `continuous_usage_stats` to False; both also accept None.
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = False
217
218


219
220
221
class FunctionDefinition(OpenAIBaseModel):
    # Type of ChatCompletionToolsParam.function; only `name` is required.
    name: str
    description: Optional[str] = None
    # Free-form dict; this class does not validate its contents.
    parameters: Optional[dict[str, Any]] = None
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238


class ChatCompletionToolsParam(OpenAIBaseModel):
    # A tool entry; "function" is the only accepted `type` and its default.
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    # Identifies a function by name only; this is the type of
    # ChatCompletionNamedToolChoiceParam.function.
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    # Tool-choice value that names one function; "function" is the only
    # accepted `type` and its default.
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


239
240
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
class LogitsProcessorConstructor(BaseModel):
    # Qualified name of the object to resolve and call; see
    # get_logits_processors below for how it is used.
    qualname: str
    # Positional / keyword arguments for that call; None means "none given".
    args: Optional[list[Any]] = None
    kwargs: Optional[dict[str, Any]] = None

    model_config = ConfigDict(extra="forbid")

248

249
# Each entry is either a qualified-name string or a constructor spec
# (qualname plus optional args/kwargs).
LogitsProcessors = list[Union[str, LogitsProcessorConstructor]]
250
251


252
253
254
def get_logits_processors(
    processors: Optional[LogitsProcessors], pattern: Optional[str]
) -> Optional[list[Any]]:
    """Resolve requested logits processors, subject to the server's pattern.

    Args:
        processors: Entries from the request. Each is either a qualified-name
            string or a LogitsProcessorConstructor carrying a qualname plus
            optional args/kwargs.
        pattern: Regular expression a qualified name must match (via
            `re.match`, i.e. anchored at the start only) to be accepted. A
            falsy value means the server accepts no logits processors.

    Returns:
        None when `processors` is empty or None; otherwise a list holding,
        per entry, the resolved object (string entries) or the result of
        calling the resolved object with the given args/kwargs (constructor
        entries).

    Raises:
        ValueError: If processors were requested but no pattern is set, if a
            qualified name does not match the pattern, or if a qualified name
            cannot be resolved.
    """
    if not processors:
        return None
    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information."
        )

    logits_processors = []
    for processor in processors:
        qualname = processor if isinstance(processor, str) else processor.qualname
        # Check the name against the pattern before resolving it, so that a
        # rejected name is never looked up.
        if not re.match(pattern, qualname):
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information."
            )
        try:
            logits_processor = resolve_obj_by_qualname(qualname)
        except Exception as e:
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {e}"
            ) from e
        if isinstance(processor, LogitsProcessorConstructor):
            # Constructor entries name a class/factory: call it with the
            # supplied arguments (missing args/kwargs count as empty).
            logits_processor = logits_processor(
                *processor.args or [], **processor.kwargs or {}
            )
        logits_processors.append(logits_processor)
    return logits_processors


286
287
288
# Item types accepted in a list-valued ResponsesRequest.input.
ResponseInputOutputItem: TypeAlias = Union[
    ResponseInputItemParam, ResponseReasoningItem, ResponseFunctionToolCall
]
289
290


291
292
293
294
class ResponsesRequest(OpenAIBaseModel):
    # Request body model for the Responses API, followed by vLLM-specific
    # extra parameters and the conversion to SamplingParams.

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/responses/create
    background: Optional[bool] = False
    include: Optional[
        list[
            Literal[
                "code_interpreter_call.outputs",
                "computer_call_output.output.image_url",
                "file_search_call.results",
                "message.input_image.image_url",
                "message.output_text.logprobs",
                "reasoning.encrypted_content",
            ],
        ]
    ] = None
    input: Union[str, list[ResponseInputOutputItem]]
    instructions: Optional[str] = None
    max_output_tokens: Optional[int] = None
    max_tool_calls: Optional[int] = None
    metadata: Optional[Metadata] = None
    model: Optional[str] = None
    parallel_tool_calls: Optional[bool] = True
    previous_response_id: Optional[str] = None
    prompt: Optional[ResponsePrompt] = None
    reasoning: Optional[Reasoning] = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto"
    store: Optional[bool] = True
    stream: Optional[bool] = False
    temperature: Optional[float] = None
    text: Optional[ResponseTextConfig] = None
    tool_choice: ToolChoice = "auto"
    tools: list[Tool] = Field(default_factory=list)
    top_logprobs: Optional[int] = 0
    top_p: Optional[float] = None
    truncation: Optional[Literal["auto", "disabled"]] = "disabled"
    user: Optional[str] = None

    # --8<-- [start:responses-extra-params]
    request_id: str = Field(
        default_factory=lambda: f"resp_{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    cache_salt: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."
        ),
    )

    enable_response_messages: bool = Field(
        default=False,
        description=(
            "Dictates whether or not to return messages as part of the "
            "response object. Currently only supported for non-streaming "
            "non-background and gpt-oss only. "
        ),
    )
    # --8<-- [end:responses-extra-params]

    # Last-resort sampling values, used when neither the request nor the
    # server-provided defaults specify one.
    _DEFAULT_SAMPLING_PARAMS = {
        "temperature": 1.0,
        "top_p": 1.0,
    }

    def to_sampling_params(
        self,
        default_max_tokens: int,
        default_sampling_params: Optional[dict] = None,
    ) -> SamplingParams:
        """Build the SamplingParams for this request.

        Args:
            default_max_tokens: Upper bound on generated tokens; also the
                value used when `max_output_tokens` is not set.
            default_sampling_params: Optional server-side defaults consulted
                for `temperature`, `top_p` and `stop_token_ids` when the
                request leaves them unset. None is treated as empty.

        Returns:
            SamplingParams created through `SamplingParams.from_optional`.

        Raises:
            NotImplementedError: If `text.format.type` is "json_object".
        """
        # The request may lower the token limit but never raise it.
        if self.max_output_tokens is None:
            max_tokens = default_max_tokens
        else:
            max_tokens = min(self.max_output_tokens, default_max_tokens)

        # Precedence: request value, then server default, then class default.
        default_sampling_params = default_sampling_params or {}
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        stop_token_ids = default_sampling_params.get("stop_token_ids")

        # Structured output
        structured_outputs = None
        if self.text is not None and self.text.format is not None:
            response_format = self.text.format
            if (
                response_format.type == "json_schema"
                and response_format.schema_ is not None
            ):
                structured_outputs = StructuredOutputsParams(
                    json=response_format.schema_
                )
            elif response_format.type == "json_object":
                raise NotImplementedError("json_object is not supported")

        # TODO: add more parameters
        return SamplingParams.from_optional(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            # Logprobs are requested only when the caller opted in via
            # `include`.
            logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
            stop_token_ids=stop_token_ids,
            output_kind=(
                RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
            ),
            structured_outputs=structured_outputs,
        )

    def is_include_output_logprobs(self) -> bool:
        """Check if the request includes output logprobs."""
        include = self.include
        # None is not a list, so an unset `include` yields False.
        return isinstance(include, list) and "message.output_text.logprobs" in include

    @model_validator(mode="before")
    @classmethod
    def validate_background(cls, data):
        """Reject `background` requests that also disable `store`.

        Non-dict inputs are passed through untouched so that pydantic's own
        validation reports them.
        """
        if not isinstance(data, dict):
            return data
        if not data.get("background"):
            return data
        # `store` counts as true when absent, matching the field default.
        if not data.get("store", True):
            raise ValueError("background can only be used when `store` is true")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt(cls, data):
        """Reject requests that set `prompt` (prompt templates).

        Non-dict inputs are passed through untouched.
        """
        if not isinstance(data, dict):
            return data
        if data.get("prompt") is not None:
            raise ValueError("prompt template is not supported")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate `cache_salt` when it is provided.

        Raises ValueError when the engine is V0 (per `envs.VLLM_USE_V1`) or
        when the salt is not a non-empty string. Non-dict inputs are passed
        through untouched.
        """
        if not isinstance(data, dict):
            return data
        if data.get("cache_salt") is not None:
            if not envs.VLLM_USE_V1:
                raise ValueError(
                    "Parameter 'cache_salt' is not supported with "
                    "this instance of vLLM, which uses engine V0."
                )
            if not isinstance(data["cache_salt"], str) or not data["cache_salt"]:
                raise ValueError(
                    "Parameter 'cache_salt' must be a non-empty string if provided."
                )
        return data

462

463
class ChatCompletionRequest(OpenAIBaseModel):
464
465
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
466
    messages: list[ChatCompletionMessageParam]
467
    model: Optional[str] = None
468
    frequency_penalty: Optional[float] = 0.0
469
    logit_bias: Optional[dict[str, float]] = None
470
    logprobs: Optional[bool] = False
471
    top_logprobs: Optional[int] = 0
472
473
    max_tokens: Optional[int] = Field(
        default=None,
474
475
        deprecated="max_tokens is deprecated in favor of "
        "the max_completion_tokens field",
476
    )
477
    max_completion_tokens: Optional[int] = None
478
479
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
480
    response_format: Optional[AnyResponseFormat] = None
481
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
482
    stop: Optional[Union[str, list[str]]] = []
Zhuohan Li's avatar
Zhuohan Li committed
483
    stream: Optional[bool] = False
484
    stream_options: Optional[StreamOptions] = None
485
486
    temperature: Optional[float] = None
    top_p: Optional[float] = None
487
    tools: Optional[list[ChatCompletionToolsParam]] = None
488
489
490
491
492
493
494
495
    tool_choice: Optional[
        Union[
            Literal["none"],
            Literal["auto"],
            Literal["required"],
            ChatCompletionNamedToolChoiceParam,
        ]
    ] = "none"
496
497
    reasoning_effort: Optional[Literal["low", "medium", "high"]] = None
    include_reasoning: bool = True
498

499
    # NOTE this will be ignored by vLLM -- the model determines the behavior
500
    parallel_tool_calls: Optional[bool] = False
Zhuohan Li's avatar
Zhuohan Li committed
501
    user: Optional[str] = None
502

503
    # --8<-- [start:chat-completion-sampling-params]
504
    best_of: Optional[int] = None
505
    use_beam_search: bool = False
506
507
508
    top_k: Optional[int] = None
    min_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
509
    length_penalty: float = 1.0
510
    stop_token_ids: Optional[list[int]] = []
511
512
513
514
515
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
516
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
517
    prompt_logprobs: Optional[int] = None
518
    allowed_token_ids: Optional[list[int]] = None
519
    bad_words: list[str] = Field(default_factory=list)
520
    # --8<-- [end:chat-completion-sampling-params]
521

522
    # --8<-- [start:chat-completion-extra-params]
523
    echo: bool = Field(
524
525
526
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
527
528
            "if they belong to the same role."
        ),
529
    )
530
    add_generation_prompt: bool = Field(
531
        default=True,
532
533
534
535
536
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
537
    )
538
539
    continue_final_message: bool = Field(
        default=False,
540
541
542
543
544
545
546
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
547
    )
548
    add_special_tokens: bool = Field(
549
550
551
552
553
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
554
            "special tokens so this should be set to false (as is the "
555
556
            "default)."
        ),
557
    )
558
    documents: Optional[list[dict[str, str]]] = Field(
559
        default=None,
560
561
562
563
564
565
566
        description=(
            "A list of dicts representing documents that will be accessible to "
            "the model if it is performing RAG (retrieval-augmented generation)."
            " If the template does not support RAG, this argument will have no "
            "effect. We recommend that each document should be a dict containing "
            '"title" and "text" keys.'
        ),
567
568
569
570
571
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
572
573
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
574
575
            "does not define one."
        ),
576
    )
577
    chat_template_kwargs: Optional[dict[str, Any]] = Field(
578
        default=None,
579
580
        description=(
            "Additional keyword args to pass to the template renderer. "
581
582
            "Will be accessible by the chat template."
        ),
583
    )
584
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
585
586
587
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
588
    structured_outputs: Optional[StructuredOutputsParams] = Field(
589
        default=None,
590
        description="Additional kwargs for structured outputs",
591
    )
592
593
594
595
596
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=(
            "`guided_json` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
597
598
            "Please pass `json` to `structured_outputs` instead."
        ),
599
600
601
602
603
604
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "`guided_regex` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
605
606
            "Please pass `regex` to `structured_outputs` instead."
        ),
607
608
609
610
611
612
    )
    guided_choice: Optional[list[str]] = Field(
        default=None,
        description=(
            "`guided_choice` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
613
614
            "Please pass `choice` to `structured_outputs` instead."
        ),
615
616
617
618
619
620
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "`guided_grammar` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
621
622
            "Please pass `grammar` to `structured_outputs` instead."
        ),
623
624
625
626
627
628
    )
    structural_tag: Optional[str] = Field(
        default=None,
        description=(
            "`structural_tag` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
629
630
            "Please pass `structural_tag` to `structured_outputs` instead."
        ),
631
632
633
634
635
636
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "`guided_decoding_backend` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
637
638
            "Please remove it from your request."
        ),
639
640
641
642
643
644
645
646
647
    )
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "`guided_whitespace_pattern` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `whitespace_pattern` to `structured_outputs` instead."
        ),
    )
648
649
650
651
652
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
653
654
            "if the served model does not use priority scheduling."
        ),
655
    )
656
657
658
659
660
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
661
662
            "through out the inference process and return in response."
        ),
663
    )
664
665
666
667
668
669
670
671
672
673
    logits_processors: Optional[LogitsProcessors] = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
674
675
676
            "{'param': 'value'}}."
        ),
    )
677
678
679
680
681
    return_tokens_as_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
682
683
684
            "that are not JSON-encodable can be identified."
        ),
    )
685
686
687
688
689
690
691
    return_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
692
693
694
            "need to map generated text back to input tokens."
        ),
    )
695
696
697
698
699
700
701
702
    cache_salt: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
703
704
705
            "to 256 bit). Not supported by vLLM engine V0."
        ),
    )
Robert Shaw's avatar
Robert Shaw committed
706
707
    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None,
708
709
        description="KVTransfer parameters used for disaggregated serving.",
    )
710

711
712
    vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
        default=None,
713
714
715
716
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
717
718
    )

719
    # --8<-- [end:chat-completion-extra-params]
Zhuohan Li's avatar
Zhuohan Li committed
720

721
722
723
724
725
    # Default sampling parameters for chat completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
726
        "top_k": 0,
727
728
729
730
        "min_p": 0.0,
    }

    def to_beam_search_params(
731
732
        self, max_tokens: int, default_sampling_params: dict
    ) -> BeamSearchParams:
733
        n = self.n if self.n is not None else 1
734
735
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
736
737
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
738
739
740
741
742
743

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
744
            length_penalty=self.length_penalty,
745
746
            include_stop_str_in_output=self.include_stop_str_in_output,
        )
747

748
    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: Optional[str],
        default_sampling_params: Optional[dict] = None,
    ) -> SamplingParams:
        """Translate this chat request into engine sampling parameters.

        Each of ``repetition_penalty``, ``temperature``, ``top_p``, ``top_k``
        and ``min_p`` is resolved from the request first, then from
        ``default_sampling_params``, then from ``_DEFAULT_SAMPLING_PARAMS``.

        Side effects:
            ``self.structured_outputs`` may be created or updated from the
            deprecated ``guided_*`` fields, from ``response_format`` and from
            the schema derived from the selected tool(s).

        Args:
            max_tokens: Maximum number of tokens to generate.
            logits_processor_pattern: Forwarded to ``get_logits_processors``
                together with the requested logits processors.
            default_sampling_params: Fallback values for unset sampling
                fields. ``None`` is treated as an empty mapping.

        Returns:
            The ``SamplingParams`` built via ``SamplingParams.from_optional``.
        """
        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        # With `echo`, prompt logprobs default to the requested top_logprobs.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        # Forward deprecated guided_* parameters to structured_outputs
        if self.structured_outputs is None:
            kwargs = dict[str, Any](
                json=self.guided_json,
                regex=self.guided_regex,
                choice=self.guided_choice,
                grammar=self.guided_grammar,
                whitespace_pattern=self.guided_whitespace_pattern,
                structural_tag=self.structural_tag,
            )
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            if len(kwargs) > 0:
                self.structured_outputs = StructuredOutputsParams(**kwargs)

        response_format = self.response_format
        json_schema_from_tool = self._get_json_schema_from_tool()
        if response_format is not None or json_schema_from_tool is not None:
            # If structured outputs wasn't already enabled,
            # we must enable it for these features to work
            if self.structured_outputs is None:
                self.structured_outputs = StructuredOutputsParams()

            # Set structured output params for response format
            if response_format is not None:
                if response_format.type == "json_object":
                    self.structured_outputs.json_object = True
                elif response_format.type == "json_schema":
                    json_schema = response_format.json_schema
                    assert json_schema is not None
                    self.structured_outputs.json = json_schema.json_schema
                elif response_format.type == "structural_tag":
                    structural_tag = response_format
                    assert structural_tag is not None and isinstance(
                        structural_tag, StructuralTagResponseFormat
                    )
                    s_tag_obj = structural_tag.model_dump(by_alias=True)
                    self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

            # Set structured output params for tool calling.
            # This runs last, so a tool-derived schema replaces any JSON
            # schema set above.
            if json_schema_from_tool is not None:
                self.structured_outputs.json = json_schema_from_tool

        # Copy vllm_xargs so that adding kv_transfer_params below does not
        # write into the request's own `vllm_xargs` field.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            bad_words=self.bad_words,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )

    def _get_json_schema_from_tool(self) -> Optional[Union[str, dict]]:
        """Return the JSON schema implied by ``tool_choice``, if any.

        Returns:
            * ``None`` when ``tool_choice`` is ``"none"``, when no tools were
              supplied, or when ``tool_choice`` is neither a named tool nor
              ``"required"``.
            * The ``parameters`` of the selected tool for a named tool choice.
            * For ``"required"``: a schema for a non-empty array whose items
              each match one of the supplied tools, with every tool's
              ``$defs`` hoisted to the top level of the schema.

        Raises:
            ValueError: If the named tool is not present in ``tools``, or if
                two tools declare the same ``$defs`` entry with different
                schemas.
        """
        # user has chosen to not use any tool
        if self.tool_choice == "none" or self.tools is None:
            return None

        # user has chosen to use a named tool
        if isinstance(self.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_name = self.tool_choice.function.name
            tools = {tool.function.name: tool.function for tool in self.tools}
            if tool_name not in tools:
                raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
            tool = tools[tool_name]
            return tool.parameters

        if self.tool_choice == "required":
            # Pydantic schema generation cannot be used since the JSON schema
            # has to be constructed for a specific instantiation of a tool list
            # so that parameters of a function are correctly generated
            # based on the chosen function name
            def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
                parameters = tool.function.parameters
                if parameters:
                    # `$defs` is hoisted to the top level of the final schema.
                    # Work on a filtered copy so the request's own tool
                    # definition is left untouched and repeated calls keep
                    # producing the same schema.
                    parameters = {
                        key: value
                        for key, value in parameters.items()
                        if key != "$defs"
                    }
                else:
                    # parameters are always generated as '{}' in the final
                    # output if they are missing from the request
                    # (i.e. are None or '{}') so the schema is
                    # updated to produce an empty object in that case
                    parameters = {"type": "object", "properties": {}}
                return {
                    "properties": {
                        "name": {"type": "string", "enum": [tool.function.name]},
                        "parameters": parameters,
                    },
                    "required": ["name", "parameters"],
                }

            def get_tool_schema_defs(tools: list[ChatCompletionToolsParam]) -> dict:
                # Merge the `$defs` of all tools; identical duplicates are
                # accepted, conflicting ones are rejected.
                all_defs = dict[str, dict[str, Any]]()
                for tool in tools:
                    if tool.function.parameters is None:
                        continue
                    defs = tool.function.parameters.get("$defs", {})
                    for def_name, def_schema in defs.items():
                        if def_name in all_defs and all_defs[def_name] != def_schema:
                            raise ValueError(
                                f"Tool definition '{def_name}' has "
                                "multiple schemas, which is not "
                                "supported."
                            )
                        all_defs[def_name] = def_schema
                return all_defs

            json_schema = {
                "type": "array",
                "minItems": 1,
                "items": {
                    "type": "object",
                    "anyOf": [get_tool_schema(tool) for tool in self.tools],
                },
            }
            json_schema_defs = get_tool_schema_defs(self.tools)
            if json_schema_defs:
                json_schema["$defs"] = json_schema_defs
            return json_schema

        return None

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` unless streaming is enabled.

        Raises:
            ValueError: If `stream_options` is truthy while `stream` is not.
        """
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `top_logprobs` on the raw input.

        Raises:
            ValueError: If `prompt_logprobs` is requested (positive or -1)
                together with `stream`, is below -1, or is -1 while
                `envs.VLLM_USE_V1` is false; or if `top_logprobs` is below
                -1, or is requested (positive or -1) without `logprobs`.
        """
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`."
                )

            if prompt_logprobs < 0 and prompt_logprobs != -1:
                raise ValueError("`prompt_logprobs` must be a positive value or -1.")
            if prompt_logprobs == -1 and not envs.VLLM_USE_V1:
                raise ValueError(
                    "`prompt_logprobs=-1` is only supported with vLLM engine V1."
                )
        if (top_logprobs := data.get("top_logprobs")) is not None:
            if top_logprobs < 0 and top_logprobs != -1:
                raise ValueError("`top_logprobs` must be a positive value or -1.")

            if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"):
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to true."
                )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Validate the constraints given in `structured_outputs`.

        Raises:
            ValueError: If more than one of `json`, `regex` and `choice` is
                set, or if any of them is combined with a `tool_choice`
                other than "none", "auto" or "required".
        """
        if isinstance(data, ValueError):
            raise data

        if data.get("structured_outputs", None) is None:
            return data

        structured_outputs_kwargs = data["structured_outputs"]
        count = sum(
            structured_outputs_kwargs.get(k) is not None
            for k in ("json", "regex", "choice")
        )
        # you can only use one kind of constraints for structured outputs
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )
        # you can only either use structured outputs or tools, not both.
        # The guard must be `count > 0`: with `count > 1` this branch could
        # never run, because that case has already raised above.
        if count > 0 and data.get("tool_choice", "none") not in (
            "none",
            "auto",
            "required",
        ):
            raise ValueError(
                "You can only either use constraints for structured outputs "
                "or tools, not both."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        """Normalise and validate `tool_choice` against `tools`.

        The raw input is updated in place: `tool_choice` defaults to "auto"
        when tools are given without it, and "required" with an empty tool
        list is rewritten to "none" with `tools` removed.

        Raises:
            ValueError: If `tool_choice` is set without `tools`, has an
                unsupported value, is a malformed named-tool object, or names
                a tool that is not among `tools`.
        """
        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        # if "tool_choice" is "none" -- no validation is needed for tools
        if "tool_choice" in data and data["tool_choice"] == "none":
            return data

        # if "tool_choice" is specified -- validation
        if "tool_choice" in data and data["tool_choice"] is not None:
            # ensure that if "tool choice" is specified, tools are present
            if "tools" not in data or data["tools"] is None:
                raise ValueError("When using `tool_choice`, `tools` must be set.")

            # make sure that tool choice is either a named tool
            # OR that it's set to "auto" or "required"
            if data["tool_choice"] not in ["auto", "required"] and not isinstance(
                data["tool_choice"], dict
            ):
                raise ValueError(
                    f"Invalid value for `tool_choice`: {data['tool_choice']}! "
                    'Only named tools, "none", "auto" or "required" '
                    "are supported."
                )

            # if tool_choice is "required" but the "tools" list is empty,
            # override the data to behave like "none" to align with
            # OpenAI’s behavior.
            if (
                data["tool_choice"] == "required"
                and isinstance(data["tools"], list)
                and len(data["tools"]) == 0
            ):
                data["tool_choice"] = "none"
                del data["tools"]
                return data

            # ensure that if "tool_choice" is specified as an object,
            # it matches a valid tool
            correct_usage_message = (
                'Correct usage: `{"type": "function",'
                ' "function": {"name": "my_function"}}`'
            )
            if isinstance(data["tool_choice"], dict):
                function = data["tool_choice"].get("function")
                if not isinstance(function, dict):
                    raise ValueError(
                        f"Invalid value for `function`: `{function}` in "
                        f"`tool_choice`! {correct_usage_message}"
                    )
                if "name" not in function:
                    raise ValueError(
                        f"Expected field `name` in `function` in "
                        f"`tool_choice`! {correct_usage_message}"
                    )
                function_name = function["name"]
                if not isinstance(function_name, str) or len(function_name) == 0:
                    raise ValueError(
                        f"Invalid `name` in `function`: `{function_name}`"
                        f" in `tool_choice`! {correct_usage_message}"
                    )

                def declared_name(tool):
                    # `tools` is still unvalidated at this point, so an entry
                    # may lack "function"/"name" or not be a dict at all.
                    # Such entries simply never match instead of raising
                    # KeyError/TypeError (presumably they are rejected by
                    # field validation afterwards -- confirm).
                    tool_function = (
                        tool.get("function")
                        if isinstance(tool, dict)
                        else getattr(tool, "function", None)
                    )
                    if isinstance(tool_function, dict):
                        return tool_function.get("name")
                    return getattr(tool_function, "name", None)

                valid_tool = any(
                    declared_name(tool) == function_name for tool in data["tools"]
                )
                if not valid_tool:
                    raise ValueError(
                        "The tool specified in `tool_choice` does not match any"
                        " of the specified `tools`"
                    )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests that set both prompt-continuation flags.

        Raises:
            ValueError: If `continue_final_message` and
                `add_generation_prompt` are both truthy.
        """
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate `cache_salt` when one is supplied.

        Raises:
            ValueError: If `cache_salt` is given while `envs.VLLM_USE_V1` is
                false, or if it is not a non-empty string.
        """
        if data.get("cache_salt") is not None:
            if not envs.VLLM_USE_V1:
                raise ValueError(
                    "Parameter 'cache_salt' is not supported with "
                    "this instance of vLLM, which uses engine V0."
                )
            if not isinstance(data["cache_salt"], str) or not data["cache_salt"]:
                raise ValueError(
                    "Parameter 'cache_salt' must be a non-empty string if provided."
                )
        return data


class CompletionRequest(OpenAIBaseModel):
    # Request body for the text completion endpoint, plus the conversions
    # from that body to engine-level beam search / sampling parameters.
    # (Kept as a comment rather than a docstring so the generated schema
    # description is not affected.)

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: Optional[str] = None
    prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: Optional[Union[str, list[str]]] = []
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    user: Optional[str] = None

    # --8<-- [start:completion-sampling-params]
    use_beam_search: bool = False
    top_k: Optional[int] = None
    min_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
    length_penalty: float = 1.0
    stop_token_ids: Optional[list[int]] = []
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
    allowed_token_ids: Optional[list[int]] = None
    prompt_logprobs: Optional[int] = None
    # --8<-- [end:completion-sampling-params]

    # --8<-- [start:completion-extra-params]
    prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    response_format: Optional[AnyResponseFormat] = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
            ", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
        ),
    )
    structured_outputs: Optional[StructuredOutputsParams] = Field(
        default=None,
        description="Additional kwargs for structured outputs",
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=(
            "`guided_json` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `json` to `structured_outputs` instead."
        ),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "`guided_regex` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `regex` to `structured_outputs` instead."
        ),
    )
    guided_choice: Optional[list[str]] = Field(
        default=None,
        description=(
            "`guided_choice` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `choice` to `structured_outputs` instead."
        ),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "`guided_grammar` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `grammar` to `structured_outputs` instead."
        ),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "`guided_decoding_backend` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please remove it from your request."
        ),
    )
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "`guided_whitespace_pattern` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `whitespace_pattern` to `structured_outputs` instead."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    logits_processors: Optional[LogitsProcessors] = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."
        ),
    )

    return_tokens_as_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."
        ),
    )
    return_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."
        ),
    )

    cache_salt: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."
        ),
    )

    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )

    vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )

    # --8<-- [end:completion-extra-params]

    # Default sampling parameters for completion requests.
    # Last-resort fallbacks, used only when a field is unset on the request
    # and missing from the `default_sampling_params` passed by the caller.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self,
        max_tokens: int,
        default_sampling_params: Optional[dict] = None,
    ) -> BeamSearchParams:
        """Build the beam search parameters for this request.

        Args:
            max_tokens: Maximum number of tokens to generate.
            default_sampling_params: Fallback values consulted when the
                request leaves a field unset. ``None`` is treated as an
                empty mapping.

        Returns:
            A ``BeamSearchParams`` whose beam width is ``n`` and whose
            temperature is resolved from the request, then
            ``default_sampling_params``, then ``_DEFAULT_SAMPLING_PARAMS``.
        """
        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        if (temperature := self.temperature) is None:
            # Use the class-level default instead of a second hard-coded 1.0
            # so both conversion methods share one source of truth.
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )

    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: Optional[str],
        default_sampling_params: Optional[dict] = None,
    ) -> SamplingParams:
        """Translate this completion request into engine sampling parameters.

        Each of ``repetition_penalty``, ``temperature``, ``top_p``, ``top_k``
        and ``min_p`` is resolved from the request first, then from
        ``default_sampling_params``, then from ``_DEFAULT_SAMPLING_PARAMS``.

        Side effects:
            ``self.structured_outputs`` may be created or updated from the
            deprecated ``guided_*`` fields and from ``response_format``.

        Args:
            max_tokens: Maximum number of tokens to generate. Replaced by 1
                when the request asks for ``echo`` with ``max_tokens == 0``.
            logits_processor_pattern: Forwarded to ``get_logits_processors``
                together with the requested logits processors.
            default_sampling_params: Fallback values for unset sampling
                fields. ``None`` is treated as an empty mapping.

        Returns:
            The ``SamplingParams`` built via ``SamplingParams.from_optional``.
        """
        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        # With `echo`, prompt logprobs default to the requested logprobs.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        echo_without_generation = self.echo and self.max_tokens == 0

        # Forward deprecated guided_* parameters to structured_outputs
        if self.structured_outputs is None:
            kwargs = dict[str, Any](
                json=self.guided_json,
                regex=self.guided_regex,
                choice=self.guided_choice,
                grammar=self.guided_grammar,
                whitespace_pattern=self.guided_whitespace_pattern,
            )
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            if len(kwargs) > 0:
                self.structured_outputs = StructuredOutputsParams(**kwargs)

        # Translate `response_format` into a structured output constraint.
        # Structured outputs are enabled on demand: previously the format was
        # silently ignored unless `structured_outputs` had already been set,
        # and only "json_object" was handled. The mapping follows the one
        # used by ChatCompletionRequest.to_sampling_params.
        response_format = self.response_format
        if response_format is not None:
            constraint: dict[str, Any] = {}
            if response_format.type == "json_object":
                constraint["json_object"] = True
            elif response_format.type == "json_schema":
                json_schema = response_format.json_schema
                if json_schema is not None:
                    constraint["json"] = json_schema.json_schema
            elif response_format.type == "structural_tag" and isinstance(
                response_format, StructuralTagResponseFormat
            ):
                s_tag_obj = response_format.model_dump(by_alias=True)
                constraint["structural_tag"] = json.dumps(s_tag_obj)

            # {'type': 'text'} (or anything unrecognised) adds no constraint
            # and therefore does not enable structured outputs.
            if constraint:
                if self.structured_outputs is None:
                    self.structured_outputs = StructuredOutputsParams()
                for field_name, value in constraint.items():
                    setattr(self.structured_outputs, field_name, value)

        # Copy vllm_xargs so that adding kv_transfer_params below does not
        # write into the request's own `vllm_xargs` field.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )

    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Allow at most one of `json`, `regex`, `choice` in `structured_outputs`.

        Raises:
            ValueError: If more than one of those constraints is set.
        """
        if data.get("structured_outputs", None) is None:
            return data

        structured_outputs_kwargs = data["structured_outputs"]
        count = sum(
            structured_outputs_kwargs.get(k) is not None
            for k in ("json", "regex", "choice")
        )
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `logprobs` on the raw input.

        Raises:
            ValueError: If `prompt_logprobs` is requested (positive or -1)
                together with `stream`, is below -1, or is -1 while
                `envs.VLLM_USE_V1` is false; or if `logprobs` is negative.
        """
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`."
                )

            if prompt_logprobs < 0 and prompt_logprobs != -1:
                raise ValueError("`prompt_logprobs` must be a positive value or -1.")
            if prompt_logprobs == -1 and not envs.VLLM_USE_V1:
                raise ValueError(
                    "`prompt_logprobs=-1` is only supported with vLLM engine V1."
                )
        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise ValueError("`logprobs` must be a positive value.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` unless streaming is enabled.

        Raises:
            ValueError: If `stream_options` is truthy while `stream` is not.
        """
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt_and_prompt_embeds(cls, data):
        """Require a non-empty `prompt` or a non-empty `prompt_embeds`.

        A prompt counts as empty when it is ``None`` or the empty string;
        embeddings count as empty when ``None`` or an empty list.

        Raises:
            ValueError: If both inputs are empty.
        """
        prompt = data.get("prompt")
        prompt_embeds = data.get("prompt_embeds")

        prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "")
        embeds_is_empty = prompt_embeds is None or (
            isinstance(prompt_embeds, list) and len(prompt_embeds) == 0
        )

        if prompt_is_empty and embeds_is_empty:
            raise ValueError(
                "Either prompt or prompt_embeds must be provided and non-empty."
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate `cache_salt` when one is supplied.

        Raises:
            ValueError: If `cache_salt` is given while `envs.VLLM_USE_V1` is
                false, or if it is not a non-empty string.
        """
        if data.get("cache_salt") is not None:
            if not envs.VLLM_USE_V1:
                raise ValueError(
                    "Parameter 'cache_salt' is not supported with "
                    "this instance of vLLM, which uses engine V0."
                )
            if not isinstance(data["cache_salt"], str) or not data["cache_salt"]:
                raise ValueError(
                    "Parameter 'cache_salt' must be a non-empty string if provided."
                )
        return data


class EmbeddingCompletionRequest(OpenAIBaseModel):
    # Embedding request whose input is raw text or token ids.
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: Optional[str] = None
    # A string, a list of strings, a list of token ids, or a list of
    # token-id lists.
    input: Union[list[int], list[list[int]], str, list[str]]
    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    # Validated to be >= -1 when provided.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:embedding-extra-params]
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    normalize: Optional[bool] = None

    # --8<-- [end:embedding-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from this request.

        Forwards `truncate_prompt_tokens`, `dimensions` and `normalize`
        unchanged.
        """
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            dimensions=self.dimensions,
            normalize=self.normalize,
        )
1531
1532


1533
class EmbeddingChatRequest(OpenAIBaseModel):
    # Embedding request whose input is a list of chat messages.
    model: Optional[str] = None
    messages: list[ChatCompletionMessageParam]

    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    # Validated to be >= -1 when provided.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:chat-embedding-extra-params]
    add_generation_prompt: bool = Field(
        default=False,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )

    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    normalize: Optional[bool] = None
    # --8<-- [end:chat-embedding-extra-params]

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject payloads that enable both prompt-continuation flags.

        Raises:
            ValueError: if `continue_final_message` and
                `add_generation_prompt` are both truthy.
        """
        # NOTE(review): `continue_final_message` is not a declared field of
        # this model in the visible code; it is only read from the raw
        # payload here. Confirm that this is intended.
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

    def to_pooling_params(self):
        """Build `PoolingParams` from this request.

        Forwards `truncate_prompt_tokens`, `dimensions` and `normalize`
        unchanged.
        """
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            dimensions=self.dimensions,
            normalize=self.normalize,
        )
1617
1618
1619
1620


# An embedding request is either completion-style (raw input) or chat-style
# (list of messages).
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]

1621
1622
# Pooling requests reuse the embedding request models as plain aliases.
PoolingCompletionRequest = EmbeddingCompletionRequest
PoolingChatRequest = EmbeddingChatRequest
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640

# Type parameter for the `data` payload of the generic IOProcessor models.
T = TypeVar("T")


class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
    # Generic request whose `data` payload is typed by the parameter `T`.
    model: Optional[str] = None

    priority: int = Field(default=0)
    """
    The priority of the request (lower means earlier handling;
    default: 0). Any priority other than 0 will raise an error
    if the served model does not use priority scheduling.
    """
    data: T
    """
    When using IOProcessor plugins, the actual input is processed
    by the plugin itself. Hence, we use a generic type for the request data
    """
    softmax: bool = True

    def to_pooling_params(self):
        """Build `PoolingParams` with task "encode" and this request's
        `softmax` flag."""
        return PoolingParams(task="encode", softmax=self.softmax)
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660


class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
    # Generic response whose `data` payload is typed by the parameter `T`.
    request_id: Optional[str] = None
    """
    The request_id associated with this response
    """
    # Defaults to the current Unix time in whole seconds.
    created_at: int = Field(default_factory=lambda: int(time.time()))

    data: T
    """
    When using IOProcessor plugins, the actual output is generated
    by the plugin itself. Hence, we use a generic type for the response data
    """


1661
# Any request shape accepted for pooling: completion-style, chat-style, or a
# generic IOProcessor request.
PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest, IOProcessorRequest]
1662

1663

1664
class ScoreRequest(OpenAIBaseModel):
    # Request to score `text_1` against `text_2`; each side may be a string,
    # a list of strings, or a multi-modal parameter.
    model: Optional[str] = None
    text_1: Union[list[str], str, ScoreMultiModalParam]
    text_2: Union[list[str], str, ScoreMultiModalParam]
    # Validated to be >= -1 when provided.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:score-extra-params]

    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    activation: Optional[bool] = None

    # --8<-- [end:score-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from `truncate_prompt_tokens` and
        `activation`."""
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation,
        )
1695
1696


1697
class RerankRequest(OpenAIBaseModel):
    # Request to rerank `documents` against a `query`.
    model: Optional[str] = None
    query: Union[str, ScoreMultiModalParam]
    documents: Union[list[str], ScoreMultiModalParam]
    # Defaults to 0 when the caller does not set it.
    top_n: int = Field(default_factory=lambda: 0)
    # Validated to be >= -1 when provided.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:rerank-extra-params]

    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    activation: Optional[bool] = None

    # --8<-- [end:rerank-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from `truncate_prompt_tokens` and
        `activation`."""
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation,
        )
1729
1730
1731


class RerankDocument(BaseModel):
    # A reranked document; both the text and the multi-modal form are
    # optional and default to None.
    text: Optional[str] = None
    multi_modal: Optional[ScoreContentPartParam] = None
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749


class RerankResult(BaseModel):
    # One rerank result entry: an index, the document, and its relevance score.
    # presumably `index` refers to the document's position in the request —
    # TODO confirm against the producer.
    index: int
    document: RerankDocument
    relevance_score: float


class RerankUsage(BaseModel):
    # Token usage reported with a rerank response.
    total_tokens: int


class RerankResponse(OpenAIBaseModel):
    # Response body of a rerank request; all fields are required.
    id: str
    model: str
    usage: RerankUsage
    results: list[RerankResult]
1751
1752


1753
class CompletionLogProbs(OpenAIBaseModel):
    # Log-probability payload for completions; every list defaults to empty.
    text_offset: list[int] = Field(default_factory=list)
    token_logprobs: list[Optional[float]] = Field(default_factory=list)
    tokens: list[str] = Field(default_factory=list)
    top_logprobs: list[Optional[dict[str, float]]] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
1758
1759


1760
class CompletionResponseChoice(OpenAIBaseModel):
    # One choice of a non-streaming completion response.
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    token_ids: Optional[list[int]] = None  # For response
    prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
    prompt_token_ids: Optional[list[int]] = None  # For prompt
Zhuohan Li's avatar
Zhuohan Li committed
1776
1777


1778
class CompletionResponse(OpenAIBaseModel):
    # Non-streaming completion response.
    # `id` defaults to "cmpl-" plus a random uuid; `created` defaults to the
    # current Unix time in whole seconds.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: Literal["text_completion"] = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseChoice]
    service_tier: Optional[Literal["auto", "default", "flex", "scale", "priority"]] = (
        None
    )
    system_fingerprint: Optional[str] = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None, description="KVTransfer parameters."
    )
Zhuohan Li's avatar
Zhuohan Li committed
1794
1795


1796
class CompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed completion chunk.
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    # not part of the OpenAI spec but for tracing the tokens
    # prompt tokens is put into choice to align with CompletionResponseChoice
    prompt_token_ids: Optional[list[int]] = None
    token_ids: Optional[list[int]] = None
Zhuohan Li's avatar
Zhuohan Li committed
1813
1814


1815
class CompletionStreamResponse(OpenAIBaseModel):
    # One chunk of a streamed completion response.
    # `id` defaults to "cmpl-" plus a random uuid; `created` defaults to the
    # current Unix time in whole seconds.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
1822
1823


1824
class EmbeddingResponseData(OpenAIBaseModel):
    # One embedding entry of an embedding response.
    index: int
    object: str = "embedding"
    # Either a list of floats or a string; presumably the string form
    # carries the "base64" encoding_format — TODO confirm against the serving
    # code.
    embedding: Union[list[float], str]
1828
1829


1830
class EmbeddingResponse(OpenAIBaseModel):
    # Response body of an embedding request.
    # `id` defaults to "embd-" plus a random uuid; `created` defaults to the
    # current Unix time in whole seconds.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[EmbeddingResponseData]
    usage: UsageInfo


1839
1840
1841
class PoolingResponseData(OpenAIBaseModel):
    # One pooling entry; `data` is a nested float list, a flat float list, or
    # a string.
    index: int
    object: str = "pooling"
    data: Union[list[list[float]], list[float], str]
1843
1844
1845
1846
1847
1848
1849


class PoolingResponse(OpenAIBaseModel):
    # Response body of a pooling request.
    # `id` defaults to "pool-" plus a random uuid; `created` defaults to the
    # current Unix time in whole seconds.
    id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[PoolingResponseData]
    usage: UsageInfo


1854
1855
1856
class ScoreResponseData(OpenAIBaseModel):
    # One score entry of a score response.
    index: int
    object: str = "score"
    score: float
1858
1859
1860
1861
1862
1863
1864


class ScoreResponse(OpenAIBaseModel):
    # Response body of a score request.
    # NOTE(review): the default id uses the same "embd-" prefix as
    # EmbeddingResponse; confirm whether that is intentional.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ScoreResponseData]
    usage: UsageInfo


1869
1870
1871
1872
1873
1874
class ClassificationRequest(OpenAIBaseModel):
    # Request to classify one string or a list of strings.
    model: Optional[str] = None
    input: Union[list[str], str]
    # Validated to be >= -1 when provided, matching the bound the embedding,
    # score and rerank request models in this file apply to the same field.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
    user: Optional[str] = None

    # --8<-- [start:classification-extra-params]
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    activation: Optional[bool] = None

    # --8<-- [end:classification-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from `truncate_prompt_tokens` and
        `activation`."""
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation,
        )
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911


class ClassificationData(OpenAIBaseModel):
    # One classification result entry.
    index: int
    # Required field, but its value may be None.
    label: Optional[str]
    # presumably one probability per class, with len(probs) == num_classes —
    # TODO confirm against the producer.
    probs: list[float]
    num_classes: int


class ClassificationResponse(OpenAIBaseModel):
    # Response body of a classification request.
    # `id` defaults to "classify-" plus a random uuid; `created` defaults to
    # the current Unix time in whole seconds.
    id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ClassificationData]
    usage: UsageInfo


1912
1913
1914
1915
1916
1917
class FunctionCall(OpenAIBaseModel):
    # A function invocation: the function name plus its arguments as a string.
    # presumably `arguments` is JSON-encoded — TODO confirm against callers.
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    # A complete tool call; `id` is generated by `make_tool_call_id` when the
    # caller does not supply one, and `type` is always "function".
    id: str = Field(default_factory=make_tool_call_id)
    type: Literal["function"] = "function"
    function: FunctionCall


1923
1924
1925
1926
1927
1928
1929
class DeltaFunctionCall(BaseModel):
    # Partial function call: both fields are optional and default to None.
    name: Optional[str] = None
    arguments: Optional[str] = None


# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
    # Only `index` is required; every other field defaults to None.
    id: Optional[str] = None
    type: Optional[Literal["function"]] = None
    index: int
    function: Optional[DeltaFunctionCall] = None


class ExtractedToolCallInformation(BaseModel):
    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: list[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: Optional[str] = None


1948
class ChatMessage(OpenAIBaseModel):
    # A complete chat message as returned in a chat completion choice.
    role: str
    content: Optional[str] = None
    refusal: Optional[str] = None
    # NOTE(review): typed as a single annotation object; confirm against the
    # OpenAI SDK, which may type this field as a list of annotations.
    annotations: Optional[OpenAIAnnotation] = None
    audio: Optional[OpenAIChatCompletionAudio] = None
    function_call: Optional[FunctionCall] = None
    # Defaults to an empty list rather than None.
    tool_calls: list[ToolCall] = Field(default_factory=list)

    # vLLM-specific fields that are not in OpenAI spec
    reasoning_content: Optional[str] = None

1960

1961
1962
1963
class ChatCompletionLogProb(OpenAIBaseModel):
    # Log probability of a single token.
    token: str
    # -9999.0 is the default when no value is supplied.
    logprob: float = -9999.0
    bytes: Optional[list[int]] = None
1965
1966
1967


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    # Workaround: redefine fields name cache so that it's not
    # shared with the super class.
    field_names: ClassVar[Optional[set[str]]] = None
    # Alternative tokens for this position; defaults to an empty list.
    top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)
1972
1973
1974


class ChatCompletionLogProbs(OpenAIBaseModel):
    # Log-probability container for chat completions; `content` is optional.
    content: Optional[list[ChatCompletionLogProbsContent]] = None
1976
1977


1978
class ChatCompletionResponseChoice(OpenAIBaseModel):
    # One choice of a non-streaming chat completion response.
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # per OpenAI spec this is the default
    finish_reason: Optional[str] = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: Optional[Union[int, str]] = None
    # not part of the OpenAI spec but is useful for tracing the tokens
    # in agent scenarios
    token_ids: Optional[list[int]] = None
1989
1990


1991
class ChatCompletionResponse(OpenAIBaseModel):
    # Non-streaming chat completion response.
    # `id` defaults to "chatcmpl-" plus a random uuid; `created` defaults to
    # the current Unix time in whole seconds.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseChoice]
    service_tier: Optional[Literal["auto", "default", "flex", "scale", "priority"]] = (
        None
    )
    system_fingerprint: Optional[str] = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
    prompt_token_ids: Optional[list[int]] = None
    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None, description="KVTransfer parameters."
    )
2009
2010


2011
class DeltaMessage(OpenAIBaseModel):
    # Incremental message payload; every field is optional, and `tool_calls`
    # defaults to an empty list.
    role: Optional[str] = None
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    tool_calls: list[DeltaToolCall] = Field(default_factory=list)
2016
2017


2018
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed chat completion chunk.
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None
    # not part of the OpenAI spec but for tracing the tokens
    token_ids: Optional[list[int]] = None
2026
2027


2028
class ChatCompletionStreamResponse(OpenAIBaseModel):
    # One chunk of a streamed chat completion response.
    # `id` defaults to "chatcmpl-" plus a random uuid; `created` defaults to
    # the current Unix time in whole seconds.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
    # not part of the OpenAI spec but for tracing the tokens
    prompt_token_ids: Optional[list[int]] = None
2037
2038


2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed transcription chunk; unlike the chat stream
    # choice it declares no `index` field.
    delta: DeltaMessage
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class TranscriptionStreamResponse(OpenAIBaseModel):
    # One chunk of a streamed transcription response.
    # `id` defaults to "trsc-" plus a random uuid; `created` defaults to the
    # current Unix time in whole seconds.
    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
    object: Literal["transcription.chunk"] = "transcription.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranscriptionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


2054
2055
2056
2057
2058
class InputTokensDetails(OpenAIBaseModel):
    # Breakdown of input tokens; `cached_tokens` is required.
    cached_tokens: int


class OutputTokensDetails(OpenAIBaseModel):
    # Breakdown of output tokens; both counters default to 0.
    reasoning_tokens: int = 0
    tool_output_tokens: int = 0
2061
2062
2063
2064
2065
2066
2067
2068


class ResponseUsage(OpenAIBaseModel):
    # Token usage attached to a ResponsesResponse; all fields are required.
    input_tokens: int
    input_tokens_details: InputTokensDetails
    output_tokens: int
    output_tokens_details: OutputTokensDetails
    total_tokens: int
2069
2070
2071
2072
2073
2074


class ResponsesResponse(OpenAIBaseModel):
    # Response object built by `from_request` from a ResponsesRequest and the
    # sampling parameters used to serve it.
    # `id` defaults to "resp_" plus a random uuid; `created_at` defaults to
    # the current Unix time in whole seconds.
    id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
    created_at: int = Field(default_factory=lambda: int(time.time()))
    # error: Optional[ResponseError] = None
    incomplete_details: Optional[IncompleteDetails] = None
    instructions: Optional[str] = None
    metadata: Optional[Metadata] = None
    model: str
    object: Literal["response"] = "response"
    output: list[ResponseOutputItem]
    # These are populated when enable_response_messages is set to True
    # TODO: Currently an issue where content of harmony messages
    # is not available when these are serialized. Metadata is available
    input_messages: Optional[list[ChatCompletionMessageParam]] = None
    output_messages: Optional[list[ChatCompletionMessageParam]] = None
    parallel_tool_calls: bool
    temperature: float
    tool_choice: ToolChoice
    tools: list[Tool]
    top_p: float
    background: bool
    max_output_tokens: int
    max_tool_calls: Optional[int] = None
    previous_response_id: Optional[str] = None
    prompt: Optional[ResponsePrompt] = None
    reasoning: Optional[Reasoning] = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"]
    status: ResponseStatus
    text: Optional[ResponseTextConfig] = None
    top_logprobs: Optional[int] = None
    truncation: Literal["auto", "disabled"]
    usage: Optional[ResponseUsage] = None
    user: Optional[str] = None

    @classmethod
    def from_request(
        cls,
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        model_name: str,
        created_time: int,
        output: list[ResponseOutputItem],
        status: ResponseStatus,
        usage: Optional[ResponseUsage] = None,
        input_messages: Optional[list[ChatCompletionMessageParam]] = None,
        output_messages: Optional[list[ChatCompletionMessageParam]] = None,
    ) -> "ResponsesResponse":
        """Build a response that echoes the request's settings.

        Args:
            request: Source of the id (`request.request_id`), instructions,
                metadata, tool settings, and the other echoed request fields.
            sampling_params: Source of `temperature`, `top_p`,
                `max_output_tokens` (from `max_tokens`) and `top_logprobs`
                (from `logprobs`).
            model_name: Value stored in `model`.
            created_time: Value stored in `created_at`.
            output: Output items of the response.
            status: Response status; "incomplete" also sets
                `incomplete_details`.
            usage: Optional token usage.
            input_messages: Optional input messages to attach.
            output_messages: Optional output messages to attach.

        Returns:
            The populated `ResponsesResponse`.
        """
        incomplete_details: Optional[IncompleteDetails] = None
        if status == "incomplete":
            # The only reason currently reported for an incomplete response.
            incomplete_details = IncompleteDetails(reason="max_output_tokens")
        # TODO: implement the other reason for incomplete_details,
        # which is content_filter
        # incomplete_details = IncompleteDetails(reason='content_filter')
        return cls(
            id=request.request_id,
            created_at=created_time,
            incomplete_details=incomplete_details,
            instructions=request.instructions,
            metadata=request.metadata,
            model=model_name,
            output=output,
            input_messages=input_messages,
            output_messages=output_messages,
            parallel_tool_calls=request.parallel_tool_calls,
            temperature=sampling_params.temperature,
            tool_choice=request.tool_choice,
            tools=request.tools,
            top_p=sampling_params.top_p,
            background=request.background,
            max_output_tokens=sampling_params.max_tokens,
            max_tool_calls=request.max_tool_calls,
            previous_response_id=request.previous_response_id,
            prompt=request.prompt,
            reasoning=request.reasoning,
            service_tier=request.service_tier,
            status=status,
            text=request.text,
            top_logprobs=sampling_params.logprobs,
            truncation=request.truncation,
            user=request.user,
            usage=usage,
        )


2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
# Locally defined event model with type "response.reasoning_part.done".
class ResponseReasoningPartDoneEvent(OpenAIBaseModel):
    content_index: int
    """The index of the content part that is done."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that is done."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.done"]
    """The type of the event. Always `response.reasoning_part.done`."""


# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
# Locally defined event model with type "response.reasoning_part.added".
class ResponseReasoningPartAddedEvent(OpenAIBaseModel):
    content_index: int
    """The index of the content part that was added."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that was added."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.added"]
    """The type of the event. Always `response.reasoning_part.added`."""


2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
# vLLM Streaming Events
# Note: we override the response type with the vLLM ResponsesResponse type
class ResponseCompletedEvent(OpenAIResponseCompletedEvent):
    # Narrows `response` from the OpenAI SDK type to vLLM's ResponsesResponse;
    # the type-ignore silences the resulting override complaint.
    response: ResponsesResponse  # type: ignore[override]


class ResponseCreatedEvent(OpenAIResponseCreatedEvent):
    # Narrows `response` from the OpenAI SDK type to vLLM's ResponsesResponse.
    response: ResponsesResponse  # type: ignore[override]


class ResponseInProgressEvent(OpenAIResponseInProgressEvent):
    # Narrows `response` from the OpenAI SDK type to vLLM's ResponsesResponse.
    response: ResponsesResponse  # type: ignore[override]


2213
# Union of every event type listed for the streaming Responses API: the three
# vLLM overrides (given as forward-reference strings) plus the OpenAI SDK and
# locally defined event classes.
StreamingResponsesResponse: TypeAlias = Union[
    "ResponseCreatedEvent",
    "ResponseInProgressEvent",
    "ResponseCompletedEvent",
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
    ResponseContentPartAddedEvent,
    ResponseContentPartDoneEvent,
    ResponseReasoningTextDeltaEvent,
    ResponseReasoningTextDoneEvent,
    ResponseReasoningPartAddedEvent,
    ResponseReasoningPartDoneEvent,
    ResponseCodeInterpreterCallInProgressEvent,
    ResponseCodeInterpreterCallCodeDeltaEvent,
    ResponseWebSearchCallInProgressEvent,
    ResponseWebSearchCallSearchingEvent,
    ResponseWebSearchCallCompletedEvent,
    ResponseCodeInterpreterCallCodeDoneEvent,
    ResponseCodeInterpreterCallInterpretingEvent,
    ResponseCodeInterpreterCallCompletedEvent,
]

2235
2236
2237
# Request body types that a batch input line may carry.
BatchRequestInputBody = Union[
    ChatCompletionRequest, EmbeddingRequest, ScoreRequest, RerankRequest
]
2238
2239


2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: The request body model is selected from `url`:
    `/v1/chat/completions`, `/v1/embeddings`, and URLs ending in `/score`
    or `/rerank` are recognised; any other URL falls back to validating the
    body against the union of supported request types.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Also used to
    # pick the model that validates `body`.
    url: str

    # The parameters of the request.
    body: BatchRequestInputBody

    @field_validator("body", mode="plain")
    @classmethod
    def check_type_for_url(cls, value: Any, info: ValidationInfo):
        """Validate `body` with the request model that matches `url`.

        Args:
            value: The raw `body` payload.
            info: Validation context; `info.data` holds the fields that have
                already validated successfully.

        Returns:
            The validated request model instance.
        """
        # Use url to disambiguate models.
        # `info.data` omits fields that failed validation, so a missing or
        # invalid `url` must not raise KeyError here; fall through to the
        # union validation instead and let the `url` error be reported.
        url: str = info.data.get("url", "")
        if url == "/v1/chat/completions":
            return ChatCompletionRequest.model_validate(value)
        if url == "/v1/embeddings":
            return TypeAdapter(EmbeddingRequest).validate_python(value)
        if url.endswith("/score"):
            return ScoreRequest.model_validate(value)
        if url.endswith("/rerank"):
            return RerankRequest.model_validate(value)
        return TypeAdapter(BatchRequestInputBody).validate_python(value)
2276

2277

2278
2279
2280
2281
2282
2283
2284
2285
class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response.
    status_code: int = 200

    # An unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[
        Union[ChatCompletionResponse, EmbeddingResponse, ScoreResponse, RerankResponse]
    ] = None
2289
2290


2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    # Required field, but its value may be None.
    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]
2307
2308


2309
# Tokenize request carrying a single plain-text prompt.
class TokenizeCompletionRequest(OpenAIBaseModel):
    model: Optional[str] = None
    prompt: str

    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) "
            "will be added to the prompt."
        ),
    )
    return_token_strs: Optional[bool] = Field(
        default=False,
        description=(
            "If true, also return the token strings "
            "corresponding to the token ids."
        ),
    )
2326
2327
2328


# Tokenize request carrying a list of chat messages that are rendered through
# the chat template before tokenization.
class TokenizeChatRequest(OpenAIBaseModel):
    model: Optional[str] = None
    messages: list[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )
    return_token_strs: Optional[bool] = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )
    continue_final_message: bool = Field(
        default=False,
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    tools: Optional[list[ChatCompletionToolsParam]] = Field(
        default=None,
        description=("A list of tools the model may call."),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests that set both mutually exclusive prompt flags.

        Runs on the raw input before field validation, so only keys that
        were explicitly supplied are visible here (field defaults are not
        applied yet).

        Raises:
            ValueError: if both `continue_final_message` and
                `add_generation_prompt` are truthy in the raw input.
        """
        # A "before" validator receives whatever the caller passed in, which
        # is not guaranteed to be a mapping. Hand such input back untouched so
        # pydantic reports a regular validation error instead of this method
        # raising AttributeError.
        if not hasattr(data, "get"):
            return data

        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

2401
2402

# A tokenize request is either prompt-based or chat-message-based.
TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
2403
2404
2405
2406
2407


# Result of a tokenize request.
class TokenizeResponse(OpenAIBaseModel):
    # NOTE(review): presumably the number of entries in `tokens` - confirm
    # against the serving code that builds this response.
    count: int
    max_model_len: int
    # Token ids.
    tokens: list[int]
    # Optional token strings; None unless populated by the server.
    token_strs: Optional[list[str]] = None
2410
2411
2412


# Request to turn a list of token ids back into text.
class DetokenizeRequest(OpenAIBaseModel):
    model: Optional[str] = None
    # Token ids to detokenize.
    tokens: list[int]
2415
2416
2417
2418


# Result of a detokenize request; carries a single string field.
class DetokenizeResponse(OpenAIBaseModel):
    prompt: str
2419
2420


2421
2422
class TokenizerInfoResponse(OpenAIBaseModel):
    """
    Response containing tokenizer configuration
    equivalent to tokenizer_config.json
    """

    # Only `tokenizer_class` is declared explicitly; extra="allow" lets any
    # additional keys through.
    model_config = ConfigDict(extra="allow")
    tokenizer_class: str


2431
# Request body for loading a LoRA adapter; both fields are required strings.
# Derives from plain `BaseModel` rather than `OpenAIBaseModel`.
class LoadLoRAAdapterRequest(BaseModel):
    lora_name: str
    lora_path: str


2436
# Request body for unloading a LoRA adapter. `lora_name` is required;
# `lora_int_id` is optional and defaults to None.
class UnloadLoRAAdapterRequest(BaseModel):
    lora_name: str
    lora_int_id: Optional[int] = None
2439
2440
2441


## Protocols for Audio
2442
# Closed set of output formats accepted by the audio request models below
# (used as the type of their `response_format` field).
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
2443
2444
2445
2446


class TranscriptionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranscription

    file: UploadFile
    """
    The audio file object (not file name) to transcribe, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: Optional[str] = None
    """ID of the model to use.
    """

    language: Optional[str] = None
    """The language of the input audio.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy and latency.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    timestamp_granularities: list[Literal["word", "segment"]] = Field(
        alias="timestamp_granularities[]", default=[]
    )
    """The timestamp granularities to populate for this transcription.

    `response_format` must be set `verbose_json` to use timestamp granularities.
    Either or both of these options are supported: `word`, or `segment`. Note:
    There is no additional latency for segment timestamps, but generating word
    timestamps incurs additional latency.
    """

    stream: Optional[bool] = False
    """When set, it will enable output to be streamed in a similar fashion
    as the Chat Completion endpoint.
    """
    # --8<-- [start:transcription-extra-params]
    # Flattened stream option to simplify form data.
    stream_include_usage: Optional[bool] = False
    stream_continuous_usage_stats: Optional[bool] = False

    vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )
    # --8<-- [end:transcription-extra-params]

    to_language: Optional[str] = None
    """The language of the output audio we transcribe to.

    Please note that this is not currently used by supported models at this
    time, but it is a placeholder for future use, matching translation api.
    """

    # TODO (varun): support the `temperature=0` behaviour described in the
    # `temperature` docstring below (auto-increase until certain thresholds
    # are met).

    # --8<-- [start:transcription-sampling-params]
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    top_p: Optional[float] = None
    """Enables nucleus (top-p) sampling, where tokens are selected from the
    smallest possible set whose cumulative probability exceeds `p`.
    """

    top_k: Optional[int] = None
    """Limits sampling to the `k` most probable tokens at each step."""

    min_p: Optional[float] = None
    """Filters out tokens with a probability lower than `min_p`, ensuring a
    minimum likelihood threshold during sampling.
    """

    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    frequency_penalty: Optional[float] = 0.0
    """The frequency penalty to use for sampling."""

    repetition_penalty: Optional[float] = None
    """The repetition penalty to use for sampling."""

    presence_penalty: Optional[float] = 0.0
    """The presence penalty to use for sampling."""
    # --8<-- [end:transcription-sampling-params]

    # Default sampling parameters for transcription requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: Optional[dict] = None
    ) -> SamplingParams:
        """Build the `SamplingParams` for this transcription request.

        Each unset (None) request value is resolved from
        `default_sampling_params` first and from `_DEFAULT_SAMPLING_PARAMS`
        second.

        Args:
            default_max_tokens: Used as `max_tokens`; the request itself has
                no max-token field.
            default_sampling_params: Optional overrides for the built-in
                defaults; None is treated as an empty dict.

        Returns:
            Sampling parameters whose output kind is DELTA when `stream` is
            truthy and FINAL_ONLY otherwise.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        # NOTE: `temperature` is declared as a non-optional float with a
        # default of 0.0, so this fallback only triggers if the attribute
        # was set to None without going through validation.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            presence_penalty=self.presence_penalty,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            extra_args=self.vllm_xargs,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_transcription_request(cls, data):
        """Check the raw request before field validation.

        Raises:
            HTTPException: 422 when `file` was sent as a plain string
                instead of a file upload.
            ValueError: when a stream option is set but `stream` is not.
        """
        # A "before" validator receives whatever the caller passed in, which
        # is not guaranteed to be a mapping. Hand such input back untouched so
        # pydantic reports a regular validation error instead of this method
        # raising AttributeError.
        if not hasattr(data, "get"):
            return data

        if isinstance(data.get("file"), str):
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail="Expected 'file' to be a file-like object, not 'str'.",
            )

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data
2627
2628
2629


# Transcription response objects
2630
2631
2632
2633
2634
# Usage record for a transcription: `type` is fixed to the literal
# "duration" and `seconds` is an integer.
class TranscriptionUsageAudio(OpenAIBaseModel):
    type: Literal["duration"] = "duration"
    seconds: int


2635
2636
2637
# Non-verbose transcription result: the text plus a required usage record.
class TranscriptionResponse(OpenAIBaseModel):
    text: str
    """The transcribed text."""
    usage: TranscriptionUsageAudio
2639
2640
2641
2642
2643
2644
2645
2646
2647
2648
2649
2650
2651
2652
2653
2654
2655
2656
2657
2658
2659
2660
2661
2662
2663
2664
2665
2666
2667
2668
2669
2670
2671
2672
2673
2674
2675
2676
2677
2678
2679
2680
2681
2682
2683
2684
2685
2686
2687
2688
2689


# One transcribed word together with its start and end time in seconds.
class TranscriptionWord(OpenAIBaseModel):
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


# One segment of a transcription; every field is required.
class TranscriptionSegment(OpenAIBaseModel):
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


# Verbose transcription result. `segments` and `words` are optional and
# default to None; the other fields are required.
class TranscriptionResponseVerbose(OpenAIBaseModel):
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The transcribed text."""

    segments: Optional[list[TranscriptionSegment]] = None
    """Segments of the transcribed text and their corresponding details."""

    words: Optional[list[TranscriptionWord]] = None
    """Extracted words and their corresponding timestamps."""
2709
2710
2711
2712
2713
2714
2715
2716
2717
2718
2719
2720
2721
2722
2723
2724
2725
2726
2727
2728
2729
2730
2731
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
2745
2746
2747
2748
2749
2750
2751
2752
2753
2754
2755


# One choice inside a streamed translation chunk.
class TranslationResponseStreamChoice(OpenAIBaseModel):
    # Incremental message content carried by this chunk.
    delta: DeltaMessage
    # Both reasons default to None.
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None


# One chunk of a streamed translation response.
class TranslationStreamResponse(OpenAIBaseModel):
    # Generated per instance: "trsl-" followed by a random uuid.
    id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
    object: Literal["translation.chunk"] = "translation.chunk"
    # Creation time as whole seconds since the epoch, taken at instantiation.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranslationResponseStreamChoice]
    # Usage information; None unless provided.
    usage: Optional[UsageInfo] = None


class TranslationRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranslation

    file: UploadFile
    """
    The audio file object (not file name) to translate, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: Optional[str] = None
    """ID of the model to use.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    # TODO support additional sampling parameters
    # --8<-- [start:translation-sampling-params]
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """
    # --8<-- [end:translation-sampling-params]

    # --8<-- [start:translation-extra-params]
    language: Optional[str] = None
    """The language of the input audio we translate from.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy.
    """

    to_language: Optional[str] = None
    """The language we translate the input audio to.

    Please note that this is not supported by all models, refer to the specific
    model documentation for more details.
    For instance, Whisper only supports `to_language=en`.
    """

    stream: Optional[bool] = False
    """Custom field not present in the original OpenAI definition. When set,
    it will enable output to be streamed in a similar fashion as the Chat
    Completion endpoint.
    """
    # Flattened stream option to simplify form data.
    stream_include_usage: Optional[bool] = False
    stream_continuous_usage_stats: Optional[bool] = False
    # --8<-- [end:translation-extra-params]

    # Default sampling parameters for translation requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: Optional[dict] = None
    ) -> SamplingParams:
        """Build the `SamplingParams` for this translation request.

        Args:
            default_max_tokens: Used as `max_tokens`; the request itself has
                no max-token field.
            default_sampling_params: Optional overrides for the built-in
                defaults; None is treated as an empty dict.

        Returns:
            Sampling parameters whose output kind is DELTA when `stream` is
            truthy and FINAL_ONLY otherwise.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        # Default parameters
        # NOTE: `temperature` is declared as a non-optional float with a
        # default of 0.0, so this fallback only triggers if the attribute
        # was set to None without going through validation.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject stream options that are set while streaming is off.

        Raises:
            ValueError: when `stream_include_usage` or
                `stream_continuous_usage_stats` is truthy but `stream` is not.
        """
        # A "before" validator receives whatever the caller passed in, which
        # is not guaranteed to be a mapping. Hand such input back untouched so
        # pydantic reports a regular validation error instead of this method
        # raising AttributeError.
        if not hasattr(data, "get"):
            return data

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data


# Translation response objects
# Non-verbose translation result: only the translated text.
class TranslationResponse(OpenAIBaseModel):
    text: str
    """The translated text."""


# One translated word together with its start and end time in seconds.
class TranslationWord(OpenAIBaseModel):
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


# One segment of a translation; every field is required.
class TranslationSegment(OpenAIBaseModel):
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


# Verbose translation result. `segments` and `words` are optional and
# default to None; the other fields are required.
class TranslationResponseVerbose(OpenAIBaseModel):
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The translated text."""

    segments: Optional[list[TranslationSegment]] = None
    """Segments of the translated text and their corresponding details."""

    words: Optional[list[TranslationWord]] = None
    """Extracted words and their corresponding timestamps."""