"tests/entrypoints/openai/test_disable_mp.py" did not exist on "8947bc3c156963dfc66e7ca1e4c436506ed6a512"
protocol.py 103 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
6
import json
Zhuohan Li's avatar
Zhuohan Li committed
7
import time
8
from http import HTTPStatus
9
from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar, Union
Zhuohan Li's avatar
Zhuohan Li committed
10

11
import regex as re
12
import torch
13
from fastapi import HTTPException, UploadFile
14
from openai.types.chat.chat_completion_audio import (
15
16
17
    ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
18
19
20
21
22
from openai.types.responses import (
    ResponseCodeInterpreterCallCodeDeltaEvent,
    ResponseCodeInterpreterCallCodeDoneEvent,
    ResponseCodeInterpreterCallCompletedEvent,
    ResponseCodeInterpreterCallInProgressEvent,
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
    ResponseCodeInterpreterCallInterpretingEvent,
    ResponseContentPartAddedEvent,
    ResponseContentPartDoneEvent,
    ResponseFunctionToolCall,
    ResponseInputItemParam,
    ResponseOutputItem,
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
    ResponsePrompt,
    ResponseReasoningItem,
    ResponseReasoningTextDeltaEvent,
    ResponseReasoningTextDoneEvent,
    ResponseStatus,
    ResponseWebSearchCallCompletedEvent,
    ResponseWebSearchCallInProgressEvent,
    ResponseWebSearchCallSearchingEvent,
)
40
from openai.types.responses import (
41
42
43
    ResponseCompletedEvent as OpenAIResponseCompletedEvent,
)
from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent
44
from openai.types.responses import (
45
46
    ResponseInProgressEvent as OpenAIResponseInProgressEvent,
)
47
from openai.types.responses.response_reasoning_item import (
48
49
    Content as ResponseReasoningTextContent,
)
50
51
52
53
54

# Backward compatibility for OpenAI client versions
try:  # For older openai versions (< 1.100.0)
    from openai.types.responses import ResponseTextConfig
except ImportError:  # For newer openai versions (>= 1.100.0)
55
    from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig
56

57
from openai.types.responses.response import IncompleteDetails, ToolChoice
58
59
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
60
61
62
63
64
65
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    TypeAdapter,
    ValidationInfo,
66
    field_serializer,
67
68
69
    field_validator,
    model_validator,
)
70
from typing_extensions import TypeAlias
Zhuohan Li's avatar
Zhuohan Li committed
71

72
from vllm import envs
73
74
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id
from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam
75
from vllm.logger import init_logger
76
from vllm.logprobs import Logprob
77
from vllm.pooling_params import PoolingParams
78
79
80
81
82
83
from vllm.sampling_params import (
    BeamSearchParams,
    RequestOutputKind,
    SamplingParams,
    StructuredOutputsParams,
)
84
from vllm.utils import random_uuid, resolve_obj_by_qualname
85

86
87
# Module-level logger, created via vLLM's init_logger and named after this module.
logger = init_logger(__name__)

88
# Integer limits of torch.long; its min/max bound the `seed` request field below.
_LONG_INFO = torch.iinfo(torch.long)
89

Zhuohan Li's avatar
Zhuohan Li committed
90

91
class OpenAIBaseModel(BaseModel):
    # Base class for the protocol models in this module: unknown request
    # fields are accepted, but a warning is logged listing them.

    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Cache class field names
    field_names: ClassVar[Optional[set[str]]] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Run normal validation, then warn about unrecognised input keys.

        Args:
            data: Raw input handed to the model; only inspected when it is
                a dict.
            handler: Validation callable supplied by pydantic for the
                wrapped step.

        Returns:
            Whatever ``handler(data)`` returns.
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Look in this class's own namespace only.  A plain attribute read
        # would also find a parent's cached set, and the subclass's own
        # fields would then be reported as ignored.
        field_names = cls.__dict__.get("field_names")
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        unknown = data.keys() - field_names
        if unknown:
            logger.warning(
                "The following fields were present in the request but ignored: %s",
                unknown,
            )
        return result
121
122


123
class ErrorInfo(OpenAIBaseModel):
    # Error detail payload: `message`, `type` and integer `code` are
    # required; `param` is optional and defaults to None.
    message: str
    type: str
    param: Optional[str] = None
    code: int
Zhuohan Li's avatar
Zhuohan Li committed
128
129


130
131
132
133
class ErrorResponse(OpenAIBaseModel):
    # Envelope model whose only declared field, `error`, holds an ErrorInfo.
    error: ErrorInfo


134
class ModelPermission(OpenAIBaseModel):
    # Permission record attached to a ModelCard.  `id` defaults to
    # "modelperm-<random uuid>" and `created` to the current Unix time in
    # whole seconds; both are generated per instance.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False
Zhuohan Li's avatar
Zhuohan Li committed
147
148


149
class ModelCard(OpenAIBaseModel):
    # Description of one served model.  Only `id` is required; `created`
    # defaults to the current Unix time in whole seconds and `permission`
    # to a fresh empty list.
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: list[ModelPermission] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
158
159


160
class ModelList(OpenAIBaseModel):
    # List container for ModelCard entries; `data` defaults to a fresh
    # empty list per instance.
    object: str = "list"
    data: list[ModelCard] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
163
164


165
166
167
168
class PromptTokenUsageInfo(OpenAIBaseModel):
    # Prompt-token detail record; `cached_tokens` is optional and defaults to None.
    cached_tokens: Optional[int] = None


169
class UsageInfo(OpenAIBaseModel):
    # Token accounting for a request.  All counters default to 0;
    # `completion_tokens` may also be None, and `prompt_tokens_details`
    # is an optional PromptTokenUsageInfo.
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
    prompt_tokens_details: Optional[PromptTokenUsageInfo] = None
Zhuohan Li's avatar
Zhuohan Li committed
174
175


176
177
178
179
180
class RequestResponseMetadata(BaseModel):
    # Plain pydantic model (not an OpenAIBaseModel) pairing a required
    # `request_id` with an optional UsageInfo in `final_usage_info`.
    request_id: str
    final_usage_info: Optional[UsageInfo] = None


181
182
183
184
185
class JsonSchemaResponseFormat(OpenAIBaseModel):
    # Payload of a "json_schema" response format: a required `name` plus
    # optional description, schema and strict flag.
    name: str
    description: Optional[str] = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: Optional[dict[str, Any]] = Field(default=None, alias="schema")
    strict: Optional[bool] = None


190
191
192
193
class StructuralTag(OpenAIBaseModel):
    # One structural-tag entry: required `begin` and `end` strings around
    # an optional schema.
    begin: str
    # schema is the field, but that causes conflicts with pydantic so
    # instead use structural_tag_schema with an alias
    structural_tag_schema: Optional[dict[str, Any]] = Field(
        default=None, alias="schema"
    )
    end: str


class StructuralTagResponseFormat(OpenAIBaseModel):
    # Response format whose `type` must be the literal "structural_tag";
    # carries a list of StructuralTag entries and a list of trigger strings.
    type: Literal["structural_tag"]
    structures: list[StructuralTag]
    triggers: list[str]


206
class ResponseFormat(OpenAIBaseModel):
    # Response format selector with an optional JSON-schema payload.
    # type must be "json_schema", "json_object", or "text"
    type: Literal["text", "json_object", "json_schema"]
    json_schema: Optional[JsonSchemaResponseFormat] = None
210
211


212
213
214
# Either of the two response-format models declared above.
AnyResponseFormat = Union[ResponseFormat, StructuralTagResponseFormat]


215
class StreamOptions(OpenAIBaseModel):
    # Streaming flags: `include_usage` defaults to True and
    # `continuous_usage_stats` to False; both may also be None.
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = False
218
219


220
221
222
class FunctionDefinition(OpenAIBaseModel):
    # Function declaration used by ChatCompletionToolsParam: a required
    # `name`, with optional `description` and `parameters` mapping.
    name: str
    description: Optional[str] = None
    parameters: Optional[dict[str, Any]] = None
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239


class ChatCompletionToolsParam(OpenAIBaseModel):
    # Tool entry: `type` is restricted to (and defaults to) "function";
    # `function` holds the FunctionDefinition.
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    # Reference to a function by its required `name`.
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    # Tool choice that names a function; `type` is restricted to (and
    # defaults to) "function".
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


240
241
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
242
243
class LogitsProcessorConstructor(BaseModel):
    # Constructor spec for a logits processor: `qualname` names the
    # class/factory, `args` and `kwargs` are the optional call arguments.
    # Extra fields are rejected (extra="forbid").
    qualname: str
    args: Optional[list[Any]] = None
    kwargs: Optional[dict[str, Any]] = None

    model_config = ConfigDict(extra="forbid")

249

250
# Each entry is either a qualified-name string or a LogitsProcessorConstructor.
LogitsProcessors = list[Union[str, LogitsProcessorConstructor]]
251
252


253
254
255
def get_logits_processors(
    processors: Optional[LogitsProcessors], pattern: Optional[str]
) -> Optional[list[Any]]:
    """Resolve requested logits processors, subject to an allow-pattern.

    Args:
        processors: Qualified-name strings and/or
            ``LogitsProcessorConstructor`` entries; may be None or empty.
        pattern: Regular expression each qualified name must match
            (checked with ``re.match``, i.e. anchored at the start only).
            A falsy pattern means no processors are accepted.

    Returns:
        None when ``processors`` is None or empty; otherwise a list with one
        resolved object per entry.  Constructor entries are called with
        their ``args``/``kwargs`` and the call result is stored.

    Raises:
        ValueError: if processors are given but ``pattern`` is falsy, if a
            qualified name does not match ``pattern``, or if a name cannot
            be resolved.
    """
    if not processors:
        return None
    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information."
        )

    logits_processors = []
    for processor in processors:
        qualname = processor if isinstance(processor, str) else processor.qualname
        if not re.match(pattern, qualname):
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information."
            )
        try:
            logits_processor = resolve_obj_by_qualname(qualname)
        except Exception as e:
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {e}"
            ) from e
        if isinstance(processor, LogitsProcessorConstructor):
            # Constructor entries name a class/factory: instantiate it with
            # the supplied arguments (missing ones become empty).
            logits_processor = logits_processor(
                *processor.args or [], **processor.kwargs or {}
            )
        logits_processors.append(logits_processor)
    return logits_processors


287
288
289
# Item types accepted inside the list form of ResponsesRequest.input.
ResponseInputOutputItem: TypeAlias = Union[
    ResponseInputItemParam, ResponseReasoningItem, ResponseFunctionToolCall
]
290
291


292
293
294
295
class ResponsesRequest(OpenAIBaseModel):
    # Request body for the Responses API, plus vLLM-specific extras.
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/responses/create
    background: Optional[bool] = False
    include: Optional[
        list[
            Literal[
                "code_interpreter_call.outputs",
                "computer_call_output.output.image_url",
                "file_search_call.results",
                "message.input_image.image_url",
                "message.output_text.logprobs",
                "reasoning.encrypted_content",
            ],
        ]
    ] = None
    input: Union[str, list[ResponseInputOutputItem]]
    instructions: Optional[str] = None
    max_output_tokens: Optional[int] = None
    max_tool_calls: Optional[int] = None
    metadata: Optional[Metadata] = None
    model: Optional[str] = None
    parallel_tool_calls: Optional[bool] = True
    previous_response_id: Optional[str] = None
    prompt: Optional[ResponsePrompt] = None
    reasoning: Optional[Reasoning] = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto"
    store: Optional[bool] = True
    stream: Optional[bool] = False
    temperature: Optional[float] = None
    text: Optional[ResponseTextConfig] = None
    tool_choice: ToolChoice = "auto"
    tools: list[Tool] = Field(default_factory=list)
    top_logprobs: Optional[int] = 0
    top_p: Optional[float] = None
    truncation: Optional[Literal["auto", "disabled"]] = "disabled"
    user: Optional[str] = None

    # --8<-- [start:responses-extra-params]
    request_id: str = Field(
        default_factory=lambda: f"resp_{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    cache_salt: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."
        ),
    )

    enable_response_messages: bool = Field(
        default=False,
        description=(
            "Dictates whether or not to return messages as part of the "
            "response object. Currently only supported for non-streaming "
            "non-background and gpt-oss only. "
        ),
    )
    # --8<-- [end:responses-extra-params]

    # Last-resort sampling values, used when neither the request nor the
    # server-provided defaults specify one.
    _DEFAULT_SAMPLING_PARAMS = {
        "temperature": 1.0,
        "top_p": 1.0,
    }

    def to_sampling_params(
        self,
        default_max_tokens: int,
        default_sampling_params: Optional[dict] = None,
    ) -> SamplingParams:
        """Build the ``SamplingParams`` for this request.

        Args:
            default_max_tokens: Upper bound on generated tokens; used as-is
                when ``max_output_tokens`` is unset, otherwise the smaller
                of the two is used.
            default_sampling_params: Optional defaults consulted for
                ``temperature``, ``top_p`` and ``stop_token_ids`` when the
                request does not set them.

        Returns:
            A ``SamplingParams`` built via ``SamplingParams.from_optional``.

        Raises:
            NotImplementedError: if ``text.format.type`` is "json_object".
        """
        if self.max_output_tokens is None:
            max_tokens = default_max_tokens
        else:
            max_tokens = min(self.max_output_tokens, default_max_tokens)

        default_sampling_params = default_sampling_params or {}
        # Precedence: request value > server default > class fallback.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        stop_token_ids = default_sampling_params.get("stop_token_ids")

        # Structured output
        structured_outputs = None
        if self.text is not None and self.text.format is not None:
            response_format = self.text.format
            if (
                response_format.type == "json_schema"
                and response_format.schema_ is not None
            ):
                structured_outputs = StructuredOutputsParams(
                    json=response_format.schema_
                )
            elif response_format.type == "json_object":
                raise NotImplementedError("json_object is not supported")

        # TODO: add more parameters
        return SamplingParams.from_optional(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            # Logprobs are only requested when the caller opted in via
            # `include`.
            logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
            stop_token_ids=stop_token_ids,
            output_kind=(
                RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
            ),
            structured_outputs=structured_outputs,
        )

    def is_include_output_logprobs(self) -> bool:
        """Check if the request includes output logprobs."""
        if self.include is None:
            return False
        return (
            isinstance(self.include, list)
            and "message.output_text.logprobs" in self.include
        )

    @model_validator(mode="before")
    @classmethod
    def validate_background(cls, data):
        """Reject ``background`` requests whose ``store`` is falsy.

        Raises:
            ValueError: if ``background`` is truthy and ``store`` is falsy.
        """
        # Non-dict input has no keys to inspect; leave it to pydantic.
        if not isinstance(data, dict):
            return data
        if not data.get("background"):
            return data
        if not data.get("store", True):
            raise ValueError("background can only be used when `store` is true")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt(cls, data):
        """Reject requests that set ``prompt``.

        Raises:
            ValueError: if ``prompt`` is present and not None.
        """
        # Non-dict input has no keys to inspect; leave it to pydantic.
        if not isinstance(data, dict):
            return data
        if data.get("prompt") is not None:
            raise ValueError("prompt template is not supported")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate ``cache_salt`` when it is provided.

        Raises:
            ValueError: if ``cache_salt`` is set while
                ``envs.VLLM_USE_V1`` is falsy, or if it is not a
                non-empty string.
        """
        # Non-dict input has no keys to inspect; leave it to pydantic.
        if not isinstance(data, dict):
            return data
        if data.get("cache_salt") is not None:
            if not envs.VLLM_USE_V1:
                raise ValueError(
                    "Parameter 'cache_salt' is not supported with "
                    "this instance of vLLM, which uses engine V0."
                )
            if not isinstance(data["cache_salt"], str) or not data["cache_salt"]:
                raise ValueError(
                    "Parameter 'cache_salt' must be a non-empty string if provided."
                )
        return data

463

464
class ChatCompletionRequest(OpenAIBaseModel):
465
466
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
467
    messages: list[ChatCompletionMessageParam]
468
    model: Optional[str] = None
469
    frequency_penalty: Optional[float] = 0.0
470
    logit_bias: Optional[dict[str, float]] = None
471
    logprobs: Optional[bool] = False
472
    top_logprobs: Optional[int] = 0
473
474
    max_tokens: Optional[int] = Field(
        default=None,
475
476
        deprecated="max_tokens is deprecated in favor of "
        "the max_completion_tokens field",
477
    )
478
    max_completion_tokens: Optional[int] = None
479
480
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
481
    response_format: Optional[AnyResponseFormat] = None
482
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
483
    stop: Optional[Union[str, list[str]]] = []
Zhuohan Li's avatar
Zhuohan Li committed
484
    stream: Optional[bool] = False
485
    stream_options: Optional[StreamOptions] = None
486
487
    temperature: Optional[float] = None
    top_p: Optional[float] = None
488
    tools: Optional[list[ChatCompletionToolsParam]] = None
489
490
491
492
493
494
495
496
    tool_choice: Optional[
        Union[
            Literal["none"],
            Literal["auto"],
            Literal["required"],
            ChatCompletionNamedToolChoiceParam,
        ]
    ] = "none"
497
498
    reasoning_effort: Optional[Literal["low", "medium", "high"]] = None
    include_reasoning: bool = True
499

500
    # NOTE this will be ignored by vLLM -- the model determines the behavior
501
    parallel_tool_calls: Optional[bool] = False
Zhuohan Li's avatar
Zhuohan Li committed
502
    user: Optional[str] = None
503

504
    # --8<-- [start:chat-completion-sampling-params]
505
    best_of: Optional[int] = None
506
    use_beam_search: bool = False
507
508
509
    top_k: Optional[int] = None
    min_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
510
    length_penalty: float = 1.0
511
    stop_token_ids: Optional[list[int]] = []
512
513
514
515
516
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
517
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
518
    prompt_logprobs: Optional[int] = None
519
    allowed_token_ids: Optional[list[int]] = None
520
    bad_words: list[str] = Field(default_factory=list)
521
    # --8<-- [end:chat-completion-sampling-params]
522

523
    # --8<-- [start:chat-completion-extra-params]
524
    echo: bool = Field(
525
526
527
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
528
529
            "if they belong to the same role."
        ),
530
    )
531
    add_generation_prompt: bool = Field(
532
        default=True,
533
534
535
536
537
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
538
    )
539
540
    continue_final_message: bool = Field(
        default=False,
541
542
543
544
545
546
547
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
548
    )
549
    add_special_tokens: bool = Field(
550
551
552
553
554
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
555
            "special tokens so this should be set to false (as is the "
556
557
            "default)."
        ),
558
    )
559
    documents: Optional[list[dict[str, str]]] = Field(
560
        default=None,
561
562
563
564
565
566
567
        description=(
            "A list of dicts representing documents that will be accessible to "
            "the model if it is performing RAG (retrieval-augmented generation)."
            " If the template does not support RAG, this argument will have no "
            "effect. We recommend that each document should be a dict containing "
            '"title" and "text" keys.'
        ),
568
569
570
571
572
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
573
574
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
575
576
            "does not define one."
        ),
577
    )
578
    chat_template_kwargs: Optional[dict[str, Any]] = Field(
579
        default=None,
580
581
        description=(
            "Additional keyword args to pass to the template renderer. "
582
583
            "Will be accessible by the chat template."
        ),
584
    )
585
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
586
587
588
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
589
    structured_outputs: Optional[StructuredOutputsParams] = Field(
590
        default=None,
591
        description="Additional kwargs for structured outputs",
592
    )
593
594
595
596
597
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=(
            "`guided_json` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
598
599
            "Please pass `json` to `structured_outputs` instead."
        ),
600
601
602
603
604
605
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "`guided_regex` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
606
607
            "Please pass `regex` to `structured_outputs` instead."
        ),
608
609
610
611
612
613
    )
    guided_choice: Optional[list[str]] = Field(
        default=None,
        description=(
            "`guided_choice` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
614
615
            "Please pass `choice` to `structured_outputs` instead."
        ),
616
617
618
619
620
621
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "`guided_grammar` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
622
623
            "Please pass `grammar` to `structured_outputs` instead."
        ),
624
625
626
627
628
629
    )
    structural_tag: Optional[str] = Field(
        default=None,
        description=(
            "`structural_tag` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
630
631
            "Please pass `structural_tag` to `structured_outputs` instead."
        ),
632
633
634
635
636
637
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "`guided_decoding_backend` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
638
639
            "Please remove it from your request."
        ),
640
641
642
643
644
645
646
647
648
    )
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "`guided_whitespace_pattern` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `whitespace_pattern` to `structured_outputs` instead."
        ),
    )
649
650
651
652
653
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
654
655
            "if the served model does not use priority scheduling."
        ),
656
    )
657
658
659
660
661
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
662
663
            "through out the inference process and return in response."
        ),
664
    )
665
666
667
668
669
670
671
672
673
674
    logits_processors: Optional[LogitsProcessors] = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
675
676
677
            "{'param': 'value'}}."
        ),
    )
678
679
680
681
682
    return_tokens_as_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
683
684
685
            "that are not JSON-encodable can be identified."
        ),
    )
686
687
688
689
690
691
692
    return_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
693
694
695
            "need to map generated text back to input tokens."
        ),
    )
696
697
698
699
700
701
702
703
    cache_salt: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
704
705
706
            "to 256 bit). Not supported by vLLM engine V0."
        ),
    )
Robert Shaw's avatar
Robert Shaw committed
707
708
    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None,
709
710
        description="KVTransfer parameters used for disaggregated serving.",
    )
711

712
713
    vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
        default=None,
714
715
716
717
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
718
719
    )

720
    # --8<-- [end:chat-completion-extra-params]
Zhuohan Li's avatar
Zhuohan Li committed
721

722
723
724
725
726
    # Default sampling parameters for chat completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
727
        "top_k": 0,
728
729
730
731
        "min_p": 0.0,
    }

    def to_beam_search_params(
732
733
        self, max_tokens: int, default_sampling_params: dict
    ) -> BeamSearchParams:
734
        n = self.n if self.n is not None else 1
735
736
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
737
738
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
739
740
741
742
743
744

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
745
            length_penalty=self.length_penalty,
746
747
            include_stop_str_in_output=self.include_stop_str_in_output,
        )
748

749
    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: Optional[str],
        default_sampling_params: dict,
    ) -> SamplingParams:
        """Translate this chat completion request into ``SamplingParams``.

        Args:
            max_tokens: Generation budget forwarded as ``max_tokens``.
            logits_processor_pattern: Passed to ``get_logits_processors``
                together with the request's ``logits_processors``.
            default_sampling_params: Mapping consulted for
                ``repetition_penalty``, ``temperature``, ``top_p``, ``top_k``
                and ``min_p`` when the request leaves them unset; missing
                keys fall back to ``_DEFAULT_SAMPLING_PARAMS``.

        Returns:
            The object produced by ``SamplingParams.from_optional``.

        Raises:
            ValueError: Propagated from ``_get_json_schema_from_tool`` (for
                example when a named tool is not present in ``tools``).

        Note:
            ``self.structured_outputs`` is created/updated in place from the
            deprecated ``guided_*`` fields, from ``response_format`` and from
            the JSON schema derived from the tool choice.
        """
        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        # With `echo` and no explicit `prompt_logprobs`, reuse `top_logprobs`
        # for the prompt tokens as well.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        # Forward deprecated guided_* parameters to structured_outputs
        if self.structured_outputs is None:
            kwargs = dict[str, Any](
                json=self.guided_json,
                regex=self.guided_regex,
                choice=self.guided_choice,
                grammar=self.guided_grammar,
                whitespace_pattern=self.guided_whitespace_pattern,
                structural_tag=self.structural_tag,
            )
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            if len(kwargs) > 0:
                self.structured_outputs = StructuredOutputsParams(**kwargs)

        response_format = self.response_format
        json_schema_from_tool = self._get_json_schema_from_tool()
        if response_format is not None or json_schema_from_tool is not None:
            # If structured outputs wasn't already enabled,
            # we must enable it for these features to work
            if self.structured_outputs is None:
                self.structured_outputs = StructuredOutputsParams()

            # Set structured output params for response format
            if response_format is not None:
                if response_format.type == "json_object":
                    self.structured_outputs.json_object = True
                elif response_format.type == "json_schema":
                    json_schema = response_format.json_schema
                    assert json_schema is not None
                    self.structured_outputs.json = json_schema.json_schema
                elif response_format.type == "structural_tag":
                    structural_tag = response_format
                    assert structural_tag is not None and isinstance(
                        structural_tag, StructuralTagResponseFormat
                    )
                    s_tag_obj = structural_tag.model_dump(by_alias=True)
                    self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

            # Set structured output params for tool calling
            # (applied last, so it overrides a `json` set above)
            if json_schema_from_tool is not None:
                self.structured_outputs.json = json_schema_from_tool

        # Work on a copy: inserting `kv_transfer_params` below must not leak
        # into the request's own `vllm_xargs` field.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            # `top_logprobs` only takes effect when `logprobs` is truthy.
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            bad_words=self.bad_words,
            allowed_token_ids=self.allowed_token_ids,
            # An empty dict is passed on as None.
            extra_args=extra_args or None,
        )
862

863
    def _get_json_schema_from_tool(self) -> Optional[Union[str, dict]]:
864
865
866
867
868
869
870
871
872
        # user has chosen to not use any tool
        if self.tool_choice == "none" or self.tools is None:
            return None

        # user has chosen to use a named tool
        if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam:
            tool_name = self.tool_choice.function.name
            tools = {tool.function.name: tool.function for tool in self.tools}
            if tool_name not in tools:
873
                raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
874
875
876
            tool = tools[tool_name]
            return tool.parameters

877
878
879
880
881
882
883
884
        if self.tool_choice == "required":
            # Pydantic schema generation cannot be used since the JSON schema
            # has to be constructed for a specific instantiation of a tool list
            # so that parameters of a function are correctly generated
            # based on the chosen function name
            def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
                return {
                    "properties": {
885
                        "name": {"type": "string", "enum": [tool.function.name]},
886
887
888
889
890
                        # parameters are always generated as '{}' in the final
                        # output if they are missing from the request
                        # (i.e. are None or '{}') so the schema is
                        # updated to produce an empty object in that case
                        "parameters": tool.function.parameters
891
892
                        if tool.function.parameters
                        else {"type": "object", "properties": {}},
893
                    },
894
                    "required": ["name", "parameters"],
895
896
                }

897
            def get_tool_schema_defs(tools: list[ChatCompletionToolsParam]) -> dict:
898
899
900
901
902
903
                all_defs = dict[str, dict[str, Any]]()
                for tool in tools:
                    if tool.function.parameters is None:
                        continue
                    defs = tool.function.parameters.pop("$defs", {})
                    for def_name, def_schema in defs.items():
904
                        if def_name in all_defs and all_defs[def_name] != def_schema:
905
906
907
                            raise ValueError(
                                f"Tool definition '{def_name}' has "
                                "multiple schemas, which is not "
908
909
                                "supported."
                            )
910
911
912
913
                        else:
                            all_defs[def_name] = def_schema
                return all_defs

914
915
916
917
918
            json_schema = {
                "type": "array",
                "minItems": 1,
                "items": {
                    "type": "object",
919
920
                    "anyOf": [get_tool_schema(tool) for tool in self.tools],
                },
921
            }
922
923
924
            json_schema_defs = get_tool_schema_defs(self.tools)
            if json_schema_defs:
                json_schema["$defs"] = json_schema_defs
925
926
            return json_schema

927
        return None
928

929
    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject ``stream_options`` on requests that do not set ``stream``.

        Runs on the raw input mapping and returns it unchanged when valid.
        """
        if not data.get("stream_options") or data.get("stream"):
            return data

        raise ValueError("Stream options can only be defined when `stream=True`.")

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate ``prompt_logprobs`` and ``top_logprobs`` on the raw input.

        ``-1`` is accepted as a special value for both fields; any other
        negative number is rejected. Returns ``data`` unchanged when valid.
        """
        prompt_logprobs = data.get("prompt_logprobs")
        if prompt_logprobs is not None:
            wants_all_prompt = prompt_logprobs == -1

            if data.get("stream") and (wants_all_prompt or prompt_logprobs > 0):
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`."
                )

            if not wants_all_prompt and prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value or -1.")

            if wants_all_prompt and not envs.VLLM_USE_V1:
                raise ValueError(
                    "`prompt_logprobs=-1` is only supported with vLLM engine V1."
                )

        top_logprobs = data.get("top_logprobs")
        if top_logprobs is not None:
            wants_all_top = top_logprobs == -1

            if not wants_all_top and top_logprobs < 0:
                raise ValueError("`top_logprobs` must be a positive value or -1.")

            # A non-zero `top_logprobs` is only meaningful with `logprobs`.
            if (wants_all_top or top_logprobs > 0) and not data.get("logprobs"):
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to true."
                )

        return data
962

963
964
    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Allow at most one of ``json`` / ``regex`` / ``choice``.

        Runs on the raw input. If ``data`` is itself a ``ValueError`` it is
        re-raised as is. Requests without ``structured_outputs`` pass through
        untouched.

        Raises:
            ValueError: If more than one of the three constraint kinds is set
                (not ``None``) inside ``structured_outputs``.
        """
        if isinstance(data, ValueError):
            raise data

        structured_outputs_kwargs = data.get("structured_outputs")
        if structured_outputs_kwargs is None:
            return data

        count = sum(
            structured_outputs_kwargs.get(k) is not None
            for k in ("json", "regex", "choice")
        )
        # you can only use one kind of constraints for structured outputs
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )

        # NOTE(review): a second check used to follow here that was meant to
        # reject structured-output constraints combined with a named
        # `tool_choice`. It was also guarded by `count > 1`, so it could never
        # run after the raise above and has been removed as dead code.
        # Deciding whether that combination should really be rejected (e.g.
        # with `count >= 1`) is a behaviour change left to the maintainers.
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        """Validate the raw ``tools`` / ``tool_choice`` combination.

        Behaviour, in order:

        * ``tool_choice`` defaults to ``"auto"`` when tools are supplied and
          no choice was given.
        * ``tool_choice == "none"`` skips every further check.
        * Any other non-``None`` choice requires ``tools`` to be set and must
          be ``"auto"``, ``"required"`` or a dict naming a function.
        * ``"required"`` with an empty ``tools`` list is rewritten to
          ``"none"`` and ``tools`` is removed from ``data``.
        * A dict choice must name a function that appears in ``tools``.

        Raises:
            ValueError: On any violation of the rules above.
        """
        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        # if "tool_choice" is "none" -- no validation is needed for tools
        if "tool_choice" in data and data["tool_choice"] == "none":
            return data

        # if "tool_choice" is specified -- validation
        if "tool_choice" in data and data["tool_choice"] is not None:
            # ensure that if "tool choice" is specified, tools are present
            if "tools" not in data or data["tools"] is None:
                raise ValueError("When using `tool_choice`, `tools` must be set.")

            # make sure that tool choice is either a named tool
            # OR that it's set to "auto" or "required"
            if data["tool_choice"] not in ["auto", "required"] and not isinstance(
                data["tool_choice"], dict
            ):
                raise ValueError(
                    f"Invalid value for `tool_choice`: {data['tool_choice']}! "
                    'Only named tools, "none", "auto" or "required" '
                    "are supported."
                )

            # if tool_choice is "required" but the "tools" list is empty,
            # override the data to behave like "none" to align with
            # OpenAI’s behavior.
            if (
                data["tool_choice"] == "required"
                and isinstance(data["tools"], list)
                and len(data["tools"]) == 0
            ):
                data["tool_choice"] = "none"
                del data["tools"]
                return data

            # ensure that if "tool_choice" is specified as an object,
            # it matches a valid tool
            correct_usage_message = (
                'Correct usage: `{"type": "function",'
                ' "function": {"name": "my_function"}}`'
            )
            if isinstance(data["tool_choice"], dict):
                function = data["tool_choice"].get("function")
                if not isinstance(function, dict):
                    raise ValueError(
                        f"Invalid value for `function`: `{function}` in "
                        f"`tool_choice`! {correct_usage_message}"
                    )
                if "name" not in function:
                    raise ValueError(
                        f"Expected field `name` in `function` in "
                        f"`tool_choice`! {correct_usage_message}"
                    )
                function_name = function["name"]
                if not isinstance(function_name, str) or len(function_name) == 0:
                    raise ValueError(
                        f"Invalid `name` in `function`: `{function_name}`"
                        f" in `tool_choice`! {correct_usage_message}"
                    )

                # This runs before field validation, so `tools` may still be
                # malformed (not a list, entries missing `function`/`name`).
                # Such entries are treated as non-matching rather than being
                # indexed blindly, which would surface as KeyError/TypeError
                # instead of a validation error.
                tools = data["tools"]
                candidates = tools if isinstance(tools, list) else []
                valid_tool = any(
                    isinstance(tool, dict)
                    and isinstance(tool.get("function"), dict)
                    and tool["function"].get("name") == function_name
                    for tool in candidates
                )
                if not valid_tool:
                    raise ValueError(
                        "The tool specified in `tool_choice` does not match any"
                        " of the specified `tools`"
                    )
        return data

1072
1073
1074
    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests enabling both prompt-continuation flags.

        ``continue_final_message`` and ``add_generation_prompt`` may not both
        be truthy in the raw input; otherwise ``data`` is returned unchanged.
        """
        wants_continuation = data.get("continue_final_message")
        if not wants_continuation or not data.get("add_generation_prompt"):
            return data

        raise ValueError(
            "Cannot set both `continue_final_message` and "
            "`add_generation_prompt` to True."
        )

1082
1083
1084
1085
1086
1087
1088
    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate ``cache_salt`` when one is supplied in the raw input.

        Raises:
            ValueError: If ``envs.VLLM_USE_V1`` is falsy, or the salt is not
                a non-empty string.
        """
        cache_salt = data.get("cache_salt")
        if cache_salt is None:
            return data

        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Parameter 'cache_salt' is not supported with "
                "this instance of vLLM, which uses engine V0."
            )

        is_non_empty_str = isinstance(cache_salt, str) and bool(cache_salt)
        if not is_non_empty_str:
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )

        return data

Zhuohan Li's avatar
Zhuohan Li committed
1097

1098
class CompletionRequest(OpenAIBaseModel):
    # Request body of the text completion endpoint. Deliberately documented
    # with comments rather than a class docstring, so the generated schema
    # description of the model is left untouched.

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: Optional[str] = None
    prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: Optional[Union[str, list[str]]] = []
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    user: Optional[str] = None

    # --8<-- [start:completion-sampling-params]
    use_beam_search: bool = False
    top_k: Optional[int] = None
    min_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
    length_penalty: float = 1.0
    stop_token_ids: Optional[list[int]] = []
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
    allowed_token_ids: Optional[list[int]] = None
    prompt_logprobs: Optional[int] = None
    # --8<-- [end:completion-sampling-params]

    # --8<-- [start:completion-extra-params]
    prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    response_format: Optional[AnyResponseFormat] = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
            ", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
        ),
    )
    structured_outputs: Optional[StructuredOutputsParams] = Field(
        default=None,
        description="Additional kwargs for structured outputs",
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=(
            "`guided_json` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `json` to `structured_outputs` instead."
        ),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "`guided_regex` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `regex` to `structured_outputs` instead."
        ),
    )
    guided_choice: Optional[list[str]] = Field(
        default=None,
        description=(
            "`guided_choice` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `choice` to `structured_outputs` instead."
        ),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "`guided_grammar` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `grammar` to `structured_outputs` instead."
        ),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "`guided_decoding_backend` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please remove it from your request."
        ),
    )
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "`guided_whitespace_pattern` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `whitespace_pattern` to `structured_outputs` instead."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    logits_processors: Optional[LogitsProcessors] = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."
        ),
    )

    return_tokens_as_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."
        ),
    )
    return_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."
        ),
    )

    cache_salt: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."
        ),
    )

    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )

    vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )

    # --8<-- [end:completion-extra-params]

    # Default sampling parameters for completion requests
    # (last-resort values when neither the request nor the caller-supplied
    # `default_sampling_params` provides one)
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self,
        max_tokens: int,
        default_sampling_params: Optional[dict] = None,
    ) -> BeamSearchParams:
        """Build the ``BeamSearchParams`` for this completion request.

        ``n`` is used as the beam width (1 when unset). The temperature comes
        from the request if set, otherwise from ``default_sampling_params``,
        otherwise from ``_DEFAULT_SAMPLING_PARAMS``.

        Args:
            max_tokens: Forwarded unchanged as ``max_tokens``.
            default_sampling_params: Optional mapping of defaults; ``None``
                is treated as an empty mapping.
        """
        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        if (temperature := self.temperature) is None:
            # Use the shared class-level fallback rather than a literal so
            # this stays in sync with `to_sampling_params`.
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )

    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: Optional[str],
        default_sampling_params: Optional[dict] = None,
    ) -> SamplingParams:
        """Translate this completion request into ``SamplingParams``.

        Args:
            max_tokens: Generation budget. Replaced by ``1`` when the request
                sets ``echo`` together with ``max_tokens == 0``.
            logits_processor_pattern: Passed to ``get_logits_processors``
                together with the request's ``logits_processors``.
            default_sampling_params: Optional mapping consulted for
                ``repetition_penalty``, ``temperature``, ``top_p``, ``top_k``
                and ``min_p`` when the request leaves them unset; missing
                keys fall back to ``_DEFAULT_SAMPLING_PARAMS``.

        Returns:
            The object produced by ``SamplingParams.from_optional``.

        Raises:
            ValueError: If ``response_format`` has type ``json_schema`` but
                carries no schema, or has type ``structural_tag`` but is not
                a ``StructuralTagResponseFormat``.

        Note:
            ``self.structured_outputs`` is created/updated in place from the
            deprecated ``guided_*`` fields and from ``response_format``.
        """
        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        # With `echo` and no explicit `prompt_logprobs`, reuse `logprobs`
        # for the prompt tokens as well.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        echo_without_generation = self.echo and self.max_tokens == 0

        # Forward deprecated guided_* parameters to structured_outputs
        if self.structured_outputs is None:
            kwargs = dict[str, Any](
                json=self.guided_json,
                regex=self.guided_regex,
                choice=self.guided_choice,
                grammar=self.guided_grammar,
                whitespace_pattern=self.guided_whitespace_pattern,
            )
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            if len(kwargs) > 0:
                self.structured_outputs = StructuredOutputsParams(**kwargs)

        # Apply `response_format`. Structured outputs are enabled on demand,
        # so the format is honoured even when no other constraint was
        # requested. A plain {'type': 'text'} format adds no constraint.
        response_format = self.response_format
        if response_format is not None and response_format.type in (
            "json_object",
            "json_schema",
            "structural_tag",
        ):
            if self.structured_outputs is None:
                self.structured_outputs = StructuredOutputsParams()

            if response_format.type == "json_object":
                self.structured_outputs.json_object = True
            elif response_format.type == "json_schema":
                json_schema = response_format.json_schema
                if json_schema is None:
                    raise ValueError(
                        "`response_format.json_schema` must be provided when "
                        "`response_format.type` is 'json_schema'."
                    )
                self.structured_outputs.json = json_schema.json_schema
            else:
                if not isinstance(response_format, StructuralTagResponseFormat):
                    raise ValueError(
                        "Unsupported `response_format` payload for type "
                        "'structural_tag'."
                    )
                s_tag_obj = response_format.model_dump(by_alias=True)
                self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

        # Work on a copy: inserting `kv_transfer_params` below must not leak
        # into the request's own `vllm_xargs` field.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            # An empty dict is passed on as None.
            extra_args=extra_args or None,
        )

    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Allow at most one of ``json`` / ``regex`` / ``choice``.

        Raises:
            ValueError: If more than one of the three is set (not ``None``)
                inside the raw ``structured_outputs`` mapping.
        """
        if data.get("structured_outputs", None) is None:
            return data

        structured_outputs_kwargs = data["structured_outputs"]
        count = sum(
            structured_outputs_kwargs.get(k) is not None
            for k in ("json", "regex", "choice")
        )
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate ``prompt_logprobs`` and ``logprobs`` on the raw input.

        ``prompt_logprobs`` accepts ``-1`` as a special value; any other
        negative number is rejected. ``logprobs`` must not be negative.
        """
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`."
                )

            if prompt_logprobs < 0 and prompt_logprobs != -1:
                raise ValueError("`prompt_logprobs` must be a positive value or -1.")
            if prompt_logprobs == -1 and not envs.VLLM_USE_V1:
                raise ValueError(
                    "`prompt_logprobs=-1` is only supported with vLLM engine V1."
                )
        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise ValueError("`logprobs` must be a positive value.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject ``stream_options`` on requests that do not set ``stream``."""
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt_and_prompt_embeds(cls, data):
        """Require a non-empty ``prompt`` or a non-empty ``prompt_embeds``.

        A prompt counts as empty when it is ``None`` or ``""``; embeddings
        count as empty when they are ``None`` or an empty list.
        """
        prompt = data.get("prompt")
        prompt_embeds = data.get("prompt_embeds")

        prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "")
        embeds_is_empty = prompt_embeds is None or (
            isinstance(prompt_embeds, list) and len(prompt_embeds) == 0
        )

        if prompt_is_empty and embeds_is_empty:
            raise ValueError(
                "Either prompt or prompt_embeds must be provided and non-empty."
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate ``cache_salt`` when one is supplied in the raw input.

        Raises:
            ValueError: If ``envs.VLLM_USE_V1`` is falsy, or the salt is not
                a non-empty string.
        """
        if data.get("cache_salt") is not None:
            if not envs.VLLM_USE_V1:
                raise ValueError(
                    "Parameter 'cache_salt' is not supported with "
                    "this instance of vLLM, which uses engine V0."
                )
            if not isinstance(data["cache_salt"], str) or not data["cache_salt"]:
                raise ValueError(
                    "Parameter 'cache_salt' must be a non-empty string if provided."
                )
        return data

Zhuohan Li's avatar
Zhuohan Li committed
1487

1488
class EmbeddingCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: Optional[str] = None
    # Raw text, a batch of texts, or token ids (single sequence or batch)
    input: Union[list[int], list[list[int]], str, list[str]]
    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:embedding-extra-params]
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    normalize: Optional[bool] = None

    # --8<-- [end:embedding-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from this request's pooling-related fields."""
        pooling_kwargs = {
            "truncate_prompt_tokens": self.truncate_prompt_tokens,
            "dimensions": self.dimensions,
            "normalize": self.normalize,
        }
        return PoolingParams(**pooling_kwargs)
1532
1533


1534
class EmbeddingChatRequest(OpenAIBaseModel):
    # Embedding request whose input is a list of chat messages
    model: Optional[str] = None
    messages: list[ChatCompletionMessageParam]

    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:chat-embedding-extra-params]
    add_generation_prompt: bool = Field(
        default=False,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )

    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description="Additional kwargs to pass to the HF processor.",
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    normalize: Optional[bool] = None
    # --8<-- [end:chat-embedding-extra-params]

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests that set both mutually exclusive prompt flags.

        Raises:
            ValueError: If `continue_final_message` and
                `add_generation_prompt` are both truthy.
        """
        both_requested = data.get("continue_final_message") and data.get(
            "add_generation_prompt"
        )
        if both_requested:
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

    def to_pooling_params(self):
        """Build `PoolingParams` from this request's pooling-related fields."""
        pooling_kwargs = {
            "truncate_prompt_tokens": self.truncate_prompt_tokens,
            "dimensions": self.dimensions,
            "normalize": self.normalize,
        }
        return PoolingParams(**pooling_kwargs)
1618
1619
1620
1621


# Either request shape accepted for embeddings
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]

# Pooling reuses the embedding request models unchanged
PoolingCompletionRequest = EmbeddingCompletionRequest
PoolingChatRequest = EmbeddingChatRequest
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641

T = TypeVar("T")


class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
    model: Optional[str] = None

    priority: int = Field(default=0)
    """
    The priority of the request (lower means earlier handling;
    default: 0). Any priority other than 0 will raise an error
    if the served model does not use priority scheduling.
    """
    data: T
    """
    When using IOProcessor plugins, the actual input is processed
    by the plugin itself. Hence, we use a generic type for the request data.
    """
    # Forwarded unchanged to `PoolingParams` by `to_pooling_params`
    softmax: bool = True

    def to_pooling_params(self):
        """Build `PoolingParams` for the "encode" task using `self.softmax`."""
        return PoolingParams(task="encode", softmax=self.softmax)
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661


class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
    request_id: Optional[str] = None
    """
    The request_id associated with this response
    """
    # Creation time in whole seconds since the epoch (`int(time.time())`)
    created_at: int = Field(default_factory=lambda: int(time.time()))

    data: T
    """
    When using IOProcessor plugins, the actual output is generated
    by the plugin itself. Hence, we use a generic type for the response data.
    """


1662
PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest, IOProcessorRequest]
1663

1664

1665
class ScoreRequest(OpenAIBaseModel):
    # Request to score `text_1` against `text_2`
    model: Optional[str] = None
    text_1: Union[list[str], str, ScoreMultiModalParam]
    text_2: Union[list[str], str, ScoreMultiModalParam]
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:score-extra-params]

    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description="Additional kwargs to pass to the HF processor.",
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    activation: Optional[bool] = None

    # --8<-- [end:score-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from the truncation and activation settings."""
        pooling_kwargs = {
            "truncate_prompt_tokens": self.truncate_prompt_tokens,
            "activation": self.activation,
        }
        return PoolingParams(**pooling_kwargs)
1696
1697


1698
class RerankRequest(OpenAIBaseModel):
    # Request to rerank `documents` against a single `query`
    model: Optional[str] = None
    query: Union[str, ScoreMultiModalParam]
    documents: Union[list[str], ScoreMultiModalParam]
    # NOTE(review): 0 presumably means "no limit" — confirm against the handler
    top_n: int = Field(default_factory=lambda: 0)
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:rerank-extra-params]

    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description="Additional kwargs to pass to the HF processor.",
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    activation: Optional[bool] = None

    # --8<-- [end:rerank-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from the truncation and activation settings."""
        pooling_kwargs = {
            "truncate_prompt_tokens": self.truncate_prompt_tokens,
            "activation": self.activation,
        }
        return PoolingParams(**pooling_kwargs)
1730
1731
1732


class RerankDocument(BaseModel):
    # A reranked document: plain text and/or a multi-modal content part
    text: Optional[str] = None
    multi_modal: Optional[ScoreContentPartParam] = None
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750


class RerankResult(BaseModel):
    # One entry of `RerankResponse.results`
    # NOTE(review): `index` presumably refers to the document's position in
    # the request's `documents` — confirm against the serving code
    index: int
    document: RerankDocument
    relevance_score: float


class RerankUsage(BaseModel):
    # Token accounting reported in `RerankResponse.usage`
    total_tokens: int


class RerankResponse(OpenAIBaseModel):
    # Response body for a rerank request; all fields are required
    id: str
    model: str
    usage: RerankUsage
    results: list[RerankResult]
1752
1753


1754
class CompletionLogProbs(OpenAIBaseModel):
    # Log-probability data for a completion choice; every list defaults to
    # empty
    text_offset: list[int] = Field(default_factory=list)
    token_logprobs: list[Optional[float]] = Field(default_factory=list)
    tokens: list[str] = Field(default_factory=list)
    top_logprobs: list[Optional[dict[str, float]]] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
1759
1760


1761
class CompletionResponseChoice(OpenAIBaseModel):
    # One generated choice of a non-streaming completion response
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    token_ids: Optional[list[int]] = None  # For response
    prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
    prompt_token_ids: Optional[list[int]] = None  # For prompt
Zhuohan Li's avatar
Zhuohan Li committed
1777
1778


1779
class CompletionResponse(OpenAIBaseModel):
    # Non-streaming completion response; `id` and `created` are generated
    # when not supplied
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: Literal["text_completion"] = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseChoice]
    service_tier: Optional[
        Literal["auto", "default", "flex", "scale", "priority"]
    ] = None
    system_fingerprint: Optional[str] = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None,
        description="KVTransfer parameters.",
    )
Zhuohan Li's avatar
Zhuohan Li committed
1795
1796


1797
class CompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed completion chunk
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    # not part of the OpenAI spec but for tracing the tokens
    # prompt tokens is put into choice to align with CompletionResponseChoice
    prompt_token_ids: Optional[list[int]] = None
    token_ids: Optional[list[int]] = None
Zhuohan Li's avatar
Zhuohan Li committed
1814
1815


1816
class CompletionStreamResponse(OpenAIBaseModel):
    # One chunk of a streamed completion; `usage` is optional on chunks
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
1823
1824


1825
class EmbeddingResponseData(OpenAIBaseModel):
    # One embedding result
    index: int
    object: str = "embedding"
    # NOTE(review): the `str` form is presumably the base64 encoding selected
    # by the request's `encoding_format` — confirm
    embedding: Union[list[float], str]
1829
1830


1831
class EmbeddingResponse(OpenAIBaseModel):
    # Response body for an embeddings request
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[EmbeddingResponseData]
    usage: UsageInfo


1840
1841
1842
class PoolingResponseData(OpenAIBaseModel):
    # One pooling result: a vector, a list of vectors, or a string payload
    index: int
    object: str = "pooling"
    data: Union[list[list[float]], list[float], str]
1844
1845
1846
1847
1848
1849
1850


class PoolingResponse(OpenAIBaseModel):
    # Response body for a pooling request
    id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[PoolingResponseData]
    usage: UsageInfo


1855
1856
1857
class ScoreResponseData(OpenAIBaseModel):
    # One score result
    index: int
    object: str = "score"
    score: float
1859
1860
1861
1862
1863
1864
1865


class ScoreResponse(OpenAIBaseModel):
    # Response body for a score request
    # NOTE(review): the generated id reuses the `embd-` prefix also used by
    # EmbeddingResponse — confirm this is intentional before changing it
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ScoreResponseData]
    usage: UsageInfo


1870
1871
1872
1873
1874
1875
class ClassificationRequest(OpenAIBaseModel):
    # Request to classify one input string or a batch of them
    model: Optional[str] = None
    input: Union[list[str], str]
    # Constrained to >= -1, matching the embedding, score and rerank request
    # models, so out-of-range values are rejected at validation time
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
    user: Optional[str] = None

    # --8<-- [start:classification-extra-params]
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    activation: Optional[bool] = None

    # --8<-- [end:classification-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from the truncation and activation settings."""
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation,
        )
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912


class ClassificationData(OpenAIBaseModel):
    # One entry of `ClassificationResponse.data`
    index: int
    # Required field, but its value may be None
    label: Optional[str]
    probs: list[float]
    num_classes: int


class ClassificationResponse(OpenAIBaseModel):
    # Response body for a classification request; `id` and `created` are
    # generated when not supplied
    id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ClassificationData]
    usage: UsageInfo


1913
1914
1915
1916
1917
1918
class FunctionCall(OpenAIBaseModel):
    # Function invocation carried by a `ToolCall`; `arguments` is kept as a
    # string
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    # A complete tool call; `id` is generated by `make_tool_call_id` when not
    # supplied
    id: str = Field(default_factory=make_tool_call_id)
    type: Literal["function"] = "function"
    function: FunctionCall


1924
1925
1926
1927
1928
1929
1930
class DeltaFunctionCall(BaseModel):
    # Function part of a `DeltaToolCall`; both fields are optional
    name: Optional[str] = None
    arguments: Optional[str] = None


# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
    id: Optional[str] = None
    type: Optional[Literal["function"]] = None
    # `index` is the only required field
    index: int
    function: Optional[DeltaFunctionCall] = None


class ExtractedToolCallInformation(BaseModel):
    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: list[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: Optional[str] = None


1949
class ChatMessage(OpenAIBaseModel):
    # Message returned in a non-streaming chat completion choice
    role: str
    content: Optional[str] = None
    refusal: Optional[str] = None
    annotations: Optional[OpenAIAnnotation] = None
    audio: Optional[OpenAIChatCompletionAudio] = None
    function_call: Optional[FunctionCall] = None
    tool_calls: list[ToolCall] = Field(default_factory=list)

    # vLLM-specific fields that are not in OpenAI spec
    reasoning_content: Optional[str] = None

1961

1962
1963
1964
class ChatCompletionLogProb(OpenAIBaseModel):
    # Log probability of a single token
    token: str
    # NOTE(review): -9999.0 is presumably a placeholder for "no logprob
    # available" — confirm against the code that fills this in
    logprob: float = -9999.0
    bytes: Optional[list[int]] = None
1966
1967
1968


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    # Workaround: redefine fields name cache so that it's not
    # shared with the super class.
    field_names: ClassVar[Optional[set[str]]] = None
    # Alternative tokens for this position; empty by default
    top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)
1973
1974
1975


class ChatCompletionLogProbs(OpenAIBaseModel):
    # Container for the log-probability entries of a chat choice
    content: Optional[list[ChatCompletionLogProbsContent]] = None
1977
1978


1979
class ChatCompletionResponseChoice(OpenAIBaseModel):
    # One choice of a non-streaming chat completion response
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # per OpenAI spec this is the default
    finish_reason: Optional[str] = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: Optional[Union[int, str]] = None
    # not part of the OpenAI spec but is useful for tracing the tokens
    # in agent scenarios
    token_ids: Optional[list[int]] = None
1990
1991


1992
class ChatCompletionResponse(OpenAIBaseModel):
    # Non-streaming chat completion response; `id` and `created` are
    # generated when not supplied
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseChoice]
    service_tier: Optional[
        Literal["auto", "default", "flex", "scale", "priority"]
    ] = None
    system_fingerprint: Optional[str] = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
    prompt_token_ids: Optional[list[int]] = None
    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None,
        description="KVTransfer parameters.",
    )
2010
2011


2012
class DeltaMessage(OpenAIBaseModel):
    # Incremental message carried by streamed choices; every field is optional
    role: Optional[str] = None
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    tool_calls: list[DeltaToolCall] = Field(default_factory=list)
2017
2018


2019
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed chat completion chunk
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None
    # not part of the OpenAI spec but for tracing the tokens
    token_ids: Optional[list[int]] = None
2027
2028


2029
class ChatCompletionStreamResponse(OpenAIBaseModel):
    # One chunk of a streamed chat completion
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
    # not part of the OpenAI spec but for tracing the tokens
    prompt_token_ids: Optional[list[int]] = None
2038
2039


2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a `TranscriptionStreamResponse` chunk
    delta: DeltaMessage
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class TranscriptionStreamResponse(OpenAIBaseModel):
    # One chunk of a streamed transcription; `id` and `created` are generated
    # when not supplied
    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
    object: Literal["transcription.chunk"] = "transcription.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranscriptionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


2055
2056
2057
2058
2059
class InputTokensDetails(OpenAIBaseModel):
    # Input-token breakdown reported in `ResponseUsage.input_tokens_details`
    cached_tokens: int


class OutputTokensDetails(OpenAIBaseModel):
    # Output-token breakdown reported in `ResponseUsage`; both counters
    # default to zero
    reasoning_tokens: int = 0
    tool_output_tokens: int = 0
2062
2063
2064
2065
2066
2067
2068
2069


class ResponseUsage(OpenAIBaseModel):
    # Token accounting attached to `ResponsesResponse.usage`
    input_tokens: int
    input_tokens_details: InputTokensDetails
    output_tokens: int
    output_tokens_details: OutputTokensDetails
    total_tokens: int
2070
2071
2072
2073
2074
2075


class ResponsesResponse(OpenAIBaseModel):
    # Response object for a Responses request; normally built from the
    # originating request through `from_request`.
    id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
    created_at: int = Field(default_factory=lambda: int(time.time()))
    # error: Optional[ResponseError] = None
    incomplete_details: Optional[IncompleteDetails] = None
    instructions: Optional[str] = None
    metadata: Optional[Metadata] = None
    model: str
    object: Literal["response"] = "response"
    output: list[ResponseOutputItem]
    parallel_tool_calls: bool
    temperature: float
    tool_choice: ToolChoice
    tools: list[Tool]
    top_p: float
    background: bool
    max_output_tokens: int
    max_tool_calls: Optional[int] = None
    previous_response_id: Optional[str] = None
    prompt: Optional[ResponsePrompt] = None
    reasoning: Optional[Reasoning] = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"]
    status: ResponseStatus
    text: Optional[ResponseTextConfig] = None
    top_logprobs: Optional[int] = None
    truncation: Literal["auto", "disabled"]
    usage: Optional[ResponseUsage] = None
    user: Optional[str] = None

    # --8<-- [start:responses-extra-params]
    # These are populated when enable_response_messages is set to True
    # NOTE: custom serialization is needed
    # see serialize_input_messages and serialize_output_messages
    input_messages: Optional[list[ChatCompletionMessageParam]] = None
    output_messages: Optional[list[ChatCompletionMessageParam]] = None
    # --8<-- [end:responses-extra-params]

    @staticmethod
    def _serialize_messages(msgs):
        """Convert a message list into JSON-friendly values.

        Shared by the `input_messages` and `output_messages` serializers.

        Args:
            msgs: The messages to serialize; may be `None` or empty.

        Returns:
            `None` when `msgs` is falsy. Otherwise a list in which dicts are
            passed through, objects exposing `to_dict()` are converted with
            it, and anything else is dumped via `model_dump_json()`.
        """
        if not msgs:
            return None
        serialized = []
        for m in msgs:
            if isinstance(m, dict):
                serialized.append(m)
            elif hasattr(m, "to_dict"):
                # Probe for the method actually called; checking `__dict__`
                # would misroute objects that lack `to_dict`.
                serialized.append(m.to_dict())
            else:
                # fallback to pydantic dump
                serialized.append(m.model_dump_json())
        return serialized

    # NOTE: openAI harmony doesn't serialize TextContent properly,
    # TODO: this fixes for TextContent, but need to verify for tools etc
    # https://github.com/openai/harmony/issues/78
    @field_serializer("output_messages", when_used="json")
    def serialize_output_messages(self, msgs, _info):
        """Serialize `output_messages` for JSON output."""
        return self._serialize_messages(msgs)

    # NOTE: openAI harmony doesn't serialize TextContent properly, this fixes it
    # https://github.com/openai/harmony/issues/78
    @field_serializer("input_messages", when_used="json")
    def serialize_input_messages(self, msgs, _info):
        """Serialize `input_messages` for JSON output."""
        return self._serialize_messages(msgs)

    @classmethod
    def from_request(
        cls,
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        model_name: str,
        created_time: int,
        output: list[ResponseOutputItem],
        status: ResponseStatus,
        usage: Optional[ResponseUsage] = None,
        input_messages: Optional[list[ChatCompletionMessageParam]] = None,
        output_messages: Optional[list[ChatCompletionMessageParam]] = None,
    ) -> "ResponsesResponse":
        """Build a response that echoes the settings of `request`.

        Args:
            request: The originating request; supplies the id and the echoed
                request settings.
            sampling_params: Supplies `temperature`, `top_p`,
                `max_output_tokens` (from `max_tokens`) and `top_logprobs`
                (from `logprobs`).
            model_name: Value for the `model` field.
            created_time: Value for the `created_at` field.
            output: The generated output items.
            status: Response status; "incomplete" also populates
                `incomplete_details`.
            usage: Optional token usage.
            input_messages: Optional input messages to attach.
            output_messages: Optional output messages to attach.

        Returns:
            The populated `ResponsesResponse`.
        """
        incomplete_details: Optional[IncompleteDetails] = None
        if status == "incomplete":
            incomplete_details = IncompleteDetails(reason="max_output_tokens")
        # TODO: implement the other reason for incomplete_details,
        # which is content_filter
        # incomplete_details = IncompleteDetails(reason='content_filter')
        return cls(
            id=request.request_id,
            created_at=created_time,
            incomplete_details=incomplete_details,
            instructions=request.instructions,
            metadata=request.metadata,
            model=model_name,
            output=output,
            input_messages=input_messages,
            output_messages=output_messages,
            parallel_tool_calls=request.parallel_tool_calls,
            temperature=sampling_params.temperature,
            tool_choice=request.tool_choice,
            tools=request.tools,
            top_p=sampling_params.top_p,
            background=request.background,
            max_output_tokens=sampling_params.max_tokens,
            max_tool_calls=request.max_tool_calls,
            previous_response_id=request.previous_response_id,
            prompt=request.prompt,
            reasoning=request.reasoning,
            service_tier=request.service_tier,
            status=status,
            text=request.text,
            top_logprobs=sampling_params.logprobs,
            truncation=request.truncation,
            user=request.user,
            usage=usage,
        )


2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
# Local definition of the `response.reasoning_part.done` streaming event.
class ResponseReasoningPartDoneEvent(OpenAIBaseModel):
    content_index: int
    """The index of the content part that is done."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that is done."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.done"]
    """The type of the event. Always `response.reasoning_part.done`."""


# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
# Local definition of the `response.reasoning_part.added` streaming event.
class ResponseReasoningPartAddedEvent(OpenAIBaseModel):
    content_index: int
    """The index of the content part that was added."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that was added."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.added"]
    """The type of the event. Always `response.reasoning_part.added`."""


2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
# vLLM Streaming Events
# Note: we override the response type with the vLLM ResponsesResponse type
class ResponseCompletedEvent(OpenAIResponseCompletedEvent):
    # Same event as the OpenAI base class, carrying the vLLM response model
    response: ResponsesResponse  # type: ignore[override]


class ResponseCreatedEvent(OpenAIResponseCreatedEvent):
    # Same event as the OpenAI base class, carrying the vLLM response model
    response: ResponsesResponse  # type: ignore[override]


class ResponseInProgressEvent(OpenAIResponseInProgressEvent):
    # Same event as the OpenAI base class, carrying the vLLM response model
    response: ResponsesResponse  # type: ignore[override]


2252
# Union of every event type listed for the streaming Responses alias. The
# first three are the vLLM overrides defined in this file and are referenced
# by name.
StreamingResponsesResponse: TypeAlias = Union[
    "ResponseCreatedEvent",
    "ResponseInProgressEvent",
    "ResponseCompletedEvent",
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
    ResponseContentPartAddedEvent,
    ResponseContentPartDoneEvent,
    ResponseReasoningTextDeltaEvent,
    ResponseReasoningTextDoneEvent,
    ResponseReasoningPartAddedEvent,
    ResponseReasoningPartDoneEvent,
    ResponseCodeInterpreterCallInProgressEvent,
    ResponseCodeInterpreterCallCodeDeltaEvent,
    ResponseWebSearchCallInProgressEvent,
    ResponseWebSearchCallSearchingEvent,
    ResponseWebSearchCallCompletedEvent,
    ResponseCodeInterpreterCallCodeDoneEvent,
    ResponseCodeInterpreterCallInterpretingEvent,
    ResponseCodeInterpreterCallCompletedEvent,
]

2274
2275
2276
# Request models accepted as the `body` of a batch input line
BatchRequestInputBody = Union[
    ChatCompletionRequest, EmbeddingRequest, ScoreRequest, RerankRequest
]
2277
2278


2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: The request body is parsed according to `url`:
    `/v1/chat/completions`, `/v1/embeddings`, and URLs ending in `/score`
    or `/rerank` are recognized.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. It selects the
    # model used to validate `body` (see `check_type_for_url`).
    url: str

    # The parameters of the request.
    body: BatchRequestInputBody

    @field_validator("body", mode="plain")
    @classmethod
    def check_type_for_url(cls, value: Any, info: ValidationInfo):
        """Validate `body` with the request model that matches `url`.

        Args:
            value: The raw, unvalidated body.
            info: Validation context; `info.data` holds the fields validated
                so far.

        Returns:
            The validated request model. When `url` is unavailable or not
            recognized, the body is validated against the whole
            `BatchRequestInputBody` union.
        """
        # Use url to disambiguate models. `info.data` has no "url" entry when
        # that field itself failed validation, so look it up defensively
        # instead of raising a KeyError.
        url = info.data.get("url")
        if url == "/v1/chat/completions":
            return ChatCompletionRequest.model_validate(value)
        if url == "/v1/embeddings":
            return TypeAdapter(EmbeddingRequest).validate_python(value)
        if isinstance(url, str):
            if url.endswith("/score"):
                return ScoreRequest.model_validate(value)
            if url.endswith("/rerank"):
                return RerankRequest.model_validate(value)
        return TypeAdapter(BatchRequestInputBody).validate_python(value)
2315

2316

2317
2318
2319
2320
2321
2322
2323
2324
class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response; 200 unless set otherwise.
    status_code: int = Field(default=200)

    # A unique identifier for the API request.
    request_id: str

    # The body of the response, typed per endpoint; defaults to None.
    body: Union[
        ChatCompletionResponse,
        EmbeddingResponse,
        ScoreResponse,
        RerankResponse,
        None,
    ] = Field(default=None)
2328
2329


2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # Developer-provided per-request id, used to pair each output line with
    # the input line it answers.
    custom_id: str

    # Required field, but its value may be None.
    response: Union[BatchResponseData, None]

    # For requests that failed with a non-HTTP error, this carries more
    # information on the cause of the failure. Required, may be None.
    error: Union[Any, None]
2346
2347


2348
class TokenizeCompletionRequest(OpenAIBaseModel):
    # Tokenize request for a plain text prompt.
    model: Optional[str] = Field(default=None)
    prompt: str

    add_special_tokens: bool = Field(
        description=(
            "If true (the default), special tokens (e.g. BOS) "
            "will be added to the prompt."
        ),
        default=True,
    )
    return_token_strs: Optional[bool] = Field(
        description=(
            "If true, also return the token strings "
            "corresponding to the token ids."
        ),
        default=False,
    )
2365
2366
2367


class TokenizeChatRequest(OpenAIBaseModel):
    # Tokenize request for a chat-style list of messages.
    model: Optional[str] = Field(default=None)
    messages: list[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
        default=True,
    )
    return_token_strs: Optional[bool] = Field(
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
        default=False,
    )
    continue_final_message: bool = Field(
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
        default=False,
    )
    add_special_tokens: bool = Field(
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
        default=False,
    )
    chat_template: Optional[str] = Field(
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
        default=None,
    )
    chat_template_kwargs: Optional[dict[str, Any]] = Field(
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
        default=None,
    )
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        description="Additional kwargs to pass to the HF processor.",
        default=None,
    )
    tools: Optional[list[ChatCompletionToolsParam]] = Field(
        description="A list of tools the model may call.",
        default=None,
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject payloads that enable both mutually exclusive prompt options.

        Runs in "before" mode, so only values present in the incoming payload
        are inspected; field defaults have not been applied at this point.
        """
        if data.get("continue_final_message"):
            if data.get("add_generation_prompt"):
                raise ValueError(
                    "Cannot set both `continue_final_message` and "
                    "`add_generation_prompt` to True."
                )
        return data

2440
2441

# A tokenize request carries either a plain prompt or a list of chat messages.
TokenizeRequest = Union[
    TokenizeCompletionRequest,
    TokenizeChatRequest,
]
2442
2443
2444
2445
2446


class TokenizeResponse(OpenAIBaseModel):
    # Result of a tokenize request.
    count: int
    max_model_len: int
    tokens: list[int]
    # Optional list of token strings; None unless populated.
    token_strs: Optional[list[str]] = Field(default=None)
2449
2450
2451


class DetokenizeRequest(OpenAIBaseModel):
    # Request to turn a list of token ids back into text.
    model: Optional[str] = Field(default=None)
    tokens: list[int]
2454
2455
2456
2457


class DetokenizeResponse(OpenAIBaseModel):
    # Response to a detokenize request; carries a single string field.
    prompt: str
2458
2459


2460
2461
class TokenizerInfoResponse(OpenAIBaseModel):
    """
    Response containing tokenizer configuration
    equivalent to tokenizer_config.json
    """

    # Accept arbitrary extra keys in addition to the declared field.
    model_config = ConfigDict(extra="allow")

    # The only explicitly declared (and required) field.
    tokenizer_class: str


2470
class LoadLoRAAdapterRequest(BaseModel):
    # Request payload for loading a LoRA adapter; both fields are required.
    lora_name: str
    lora_path: str


2475
class UnloadLoRAAdapterRequest(BaseModel):
    # Request payload for unloading a LoRA adapter.
    lora_name: str
    # Optional integer id; None when not supplied.
    lora_int_id: Optional[int] = None
2478
2479
2480


## Protocols for Audio
2481
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
2482
2483
2484
2485


class TranscriptionRequest(OpenAIBaseModel):
    # Field order follows the official OpenAI API documentation:
    # https://platform.openai.com/docs/api-reference/audio/createTranscription

    file: UploadFile
    """
    The audio file object (not file name) to transcribe, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: Optional[str] = Field(default=None)
    """ID of the model to use.
    """

    language: Optional[str] = Field(default=None)
    """The language of the input audio.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy and latency.
    """

    prompt: str = ""
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = "json"
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    ## TODO (varun) : Support if set to 0, certain thresholds are met !!

    timestamp_granularities: list[Literal["word", "segment"]] = Field(
        default=[],
        alias="timestamp_granularities[]",
    )
    """The timestamp granularities to populate for this transcription.

    `response_format` must be set `verbose_json` to use timestamp granularities.
    Either or both of these options are supported: `word`, or `segment`. Note:
    There is no additional latency for segment timestamps, but generating word
    timestamps incurs additional latency.
    """

    stream: Optional[bool] = Field(default=False)
    """When set, it will enable output to be streamed in a similar fashion
    as the Chat Completion endpoint.
    """
    # --8<-- [start:transcription-extra-params]
    # Stream options are flattened into top-level fields to simplify form data.
    stream_include_usage: Optional[bool] = Field(default=False)
    stream_continuous_usage_stats: Optional[bool] = Field(default=False)

    vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
        default=None,
    )
    # --8<-- [end:transcription-extra-params]

    to_language: Optional[str] = Field(default=None)
    """The language of the output audio we transcribe to.

    Please note that this is not currently used by supported models at this
    time, but it is a placeholder for future use, matching translation api.
    """

    # --8<-- [start:transcription-sampling-params]
    temperature: float = 0.0
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    top_p: Optional[float] = Field(default=None)
    """Enables nucleus (top-p) sampling, where tokens are selected from the
    smallest possible set whose cumulative probability exceeds `p`.
    """

    top_k: Optional[int] = Field(default=None)
    """Limits sampling to the `k` most probable tokens at each step."""

    min_p: Optional[float] = Field(default=None)
    """Filters out tokens with a probability lower than `min_p`, ensuring a
    minimum likelihood threshold during sampling.
    """

    seed: Optional[int] = Field(default=None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    frequency_penalty: Optional[float] = Field(default=0.0)
    """The frequency penalty to use for sampling."""

    repetition_penalty: Optional[float] = Field(default=None)
    """The repetition penalty to use for sampling."""

    presence_penalty: Optional[float] = Field(default=0.0)
    """The presence penalty to use for sampling."""
    # --8<-- [end:transcription-sampling-params]

    # Last-resort values used when neither the request nor the caller-supplied
    # defaults provide a sampling parameter.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: Optional[dict] = None
    ) -> SamplingParams:
        """Build the `SamplingParams` for this transcription request.

        Each tunable parameter is taken from the request when it is not None,
        otherwise from `default_sampling_params`, otherwise from
        `_DEFAULT_SAMPLING_PARAMS`. `max_tokens` is always
        `default_max_tokens`.
        """
        fallbacks = {} if default_sampling_params is None else default_sampling_params

        def resolve(key: str):
            # Request value wins; then caller defaults; then class defaults.
            requested = getattr(self, key)
            if requested is not None:
                return requested
            return fallbacks.get(key, self._DEFAULT_SAMPLING_PARAMS[key])

        temperature = resolve("temperature")
        top_p = resolve("top_p")
        top_k = resolve("top_k")
        min_p = resolve("min_p")
        repetition_penalty = resolve("repetition_penalty")

        # Streaming requests receive deltas; others only the final output.
        if self.stream:
            output_kind = RequestOutputKind.DELTA
        else:
            output_kind = RequestOutputKind.FINAL_ONLY

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=default_max_tokens,
            seed=self.seed,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            presence_penalty=self.presence_penalty,
            output_kind=output_kind,
            extra_args=self.vllm_xargs,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_transcription_request(cls, data):
        """Check the raw payload before field validation.

        Raises an `HTTPException` (422) when `file` is a string, and a
        `ValueError` when a stream option is set while `stream` is not.
        """
        uploaded = data.get("file")
        if isinstance(uploaded, str):
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail="Expected 'file' to be a file-like object, not 'str'.",
            )

        streaming = data.get("stream", False)
        option_requested = False
        for option in ("stream_include_usage", "stream_continuous_usage_stats"):
            if data.get(option, False):
                option_requested = True
                break
        if option_requested and not streaming:
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data
2666
2667
2668


# Transcription response objects
2669
2670
2671
2672
2673
class TranscriptionUsageAudio(OpenAIBaseModel):
    # Usage record for a transcription: a fixed "duration" tag plus a number
    # of seconds.
    type: Literal["duration"] = Field(default="duration")
    seconds: int


2674
2675
2676
class TranscriptionResponse(OpenAIBaseModel):
    # Transcription result: the text plus its usage record.
    text: str
    """The transcribed text."""
    # Required usage record for the request.
    usage: TranscriptionUsageAudio
2678
2679
2680
2681
2682
2683
2684
2685
2686
2687
2688
2689
2690
2691
2692
2693
2694
2695
2696
2697
2698
2699
2700
2701
2702
2703
2704
2705
2706
2707
2708
2709
2710
2711
2712
2713
2714
2715
2716
2717
2718
2719
2720
2721
2722
2723
2724
2725
2726
2727
2728


class TranscriptionWord(OpenAIBaseModel):
    # One transcribed word with its start and end timestamps.
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranscriptionSegment(OpenAIBaseModel):
    # One segment of a transcription; every field below is required.
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranscriptionResponseVerbose(OpenAIBaseModel):
    # Verbose transcription result; segments and words are optional.
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The transcribed text."""

    segments: Optional[list[TranscriptionSegment]] = Field(default=None)
    """Segments of the transcribed text and their corresponding details."""

    words: Optional[list[TranscriptionWord]] = Field(default=None)
    """Extracted words and their corresponding timestamps."""
2748
2749
2750
2751
2752
2753
2754
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
2785
2786
2787
2788
2789
2790
2791
2792
2793
2794


class TranslationResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed translation chunk.
    delta: DeltaMessage
    finish_reason: Optional[str] = Field(default=None)
    stop_reason: Union[int, str, None] = Field(default=None)


class TranslationStreamResponse(OpenAIBaseModel):
    # One chunk of a streamed translation response.

    # Generated per instance as "trsl-" followed by a random uuid.
    id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
    object: Literal["translation.chunk"] = Field(default="translation.chunk")
    # Whole-second timestamp taken when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranslationResponseStreamChoice]
    usage: Optional[UsageInfo] = None


class TranslationRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranslation

    file: UploadFile
    """
    The audio file object (not file name) to translate, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: Optional[str] = None
    """ID of the model to use.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    # TODO support additional sampling parameters
    # --8<-- [start:translation-sampling-params]
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """
    # --8<-- [end:translation-sampling-params]

    # --8<-- [start:translation-extra-params]
    language: Optional[str] = None
    """The language of the input audio we translate from.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy.
    """

    to_language: Optional[str] = None
    """The language we translate the input audio to.

    Please note that this is not supported by all models, refer to the specific
    model documentation for more details.
    For instance, Whisper only supports `to_language=en`.
    """

    stream: Optional[bool] = False
    """Custom field not present in the original OpenAI definition. When set,
    it will enable output to be streamed in a similar fashion as the Chat
    Completion endpoint.
    """
    # Flattened stream option to simplify form data.
    stream_include_usage: Optional[bool] = False
    stream_continuous_usage_stats: Optional[bool] = False
    # --8<-- [end:translation-extra-params]

    # Default sampling parameters for translation requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: Optional[dict] = None
    ) -> SamplingParams:
        """Build the `SamplingParams` for this translation request.

        Args:
            default_max_tokens: Used as `max_tokens`; the request itself has
                no field to override it.
            default_sampling_params: Optional fallback values, consulted only
                for parameters that are None on the request.

        Returns:
            Sampling parameters whose output kind is `DELTA` when `stream` is
            truthy and `FINAL_ONLY` otherwise.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        # Default parameters
        # `temperature` is declared as a non-optional float defaulting to
        # 0.0, so this fallback only applies if the attribute is None anyway.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject stream options supplied without `stream` being enabled.

        Runs in "before" mode on the raw payload, so only values actually
        present in `data` are inspected.

        Raises:
            ValueError: If `stream_include_usage` or
                `stream_continuous_usage_stats` is truthy while `stream` is
                not.
        """
        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data


# Translation response objects
class TranslationResponse(OpenAIBaseModel):
    # Translation result carrying only the translated text.
    text: str
    """The translated text."""


class TranslationWord(OpenAIBaseModel):
    # One translated word with its start and end timestamps.
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranslationSegment(OpenAIBaseModel):
    # One segment of a translation; every field below is required.
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranslationResponseVerbose(OpenAIBaseModel):
    # Verbose translation result; segments and words are optional.
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The translated text."""

    segments: Optional[list[TranslationSegment]] = Field(default=None)
    """Segments of the translated text and their corresponding details."""

    words: Optional[list[TranslationWord]] = Field(default=None)
    """Extracted words and their corresponding timestamps."""