"vllm/vscode:/vscode.git/clone" did not exist on "f9060e6b3db426936497f7d3c139fe5504e281f0"
protocol.py 104 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
6
import json
Zhuohan Li's avatar
Zhuohan Li committed
7
import time
8
from http import HTTPStatus
9
from typing import Annotated, Any, ClassVar, Generic, Literal, TypeAlias, TypeVar
Zhuohan Li's avatar
Zhuohan Li committed
10

11
import regex as re
12
import torch
13
from fastapi import HTTPException, UploadFile
14
from openai.types.chat.chat_completion_audio import (
15
16
17
    ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
18
19
20
21
22
from openai.types.responses import (
    ResponseCodeInterpreterCallCodeDeltaEvent,
    ResponseCodeInterpreterCallCodeDoneEvent,
    ResponseCodeInterpreterCallCompletedEvent,
    ResponseCodeInterpreterCallInProgressEvent,
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
    ResponseCodeInterpreterCallInterpretingEvent,
    ResponseContentPartAddedEvent,
    ResponseContentPartDoneEvent,
    ResponseFunctionToolCall,
    ResponseInputItemParam,
    ResponseOutputItem,
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
    ResponsePrompt,
    ResponseReasoningItem,
    ResponseReasoningTextDeltaEvent,
    ResponseReasoningTextDoneEvent,
    ResponseStatus,
    ResponseWebSearchCallCompletedEvent,
    ResponseWebSearchCallInProgressEvent,
    ResponseWebSearchCallSearchingEvent,
)
40
from openai.types.responses import (
41
42
43
    ResponseCompletedEvent as OpenAIResponseCompletedEvent,
)
from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent
44
from openai.types.responses import (
45
46
    ResponseInProgressEvent as OpenAIResponseInProgressEvent,
)
47
from openai.types.responses.response_reasoning_item import (
48
49
    Content as ResponseReasoningTextContent,
)
50

51
52
53
54
55
56
from vllm.utils.serial_utils import (
    EmbedDType,
    EncodingFormat,
    Endianness,
)

57
58
59
60
# Backward compatibility for OpenAI client versions
try:  # For older openai versions (< 1.100.0)
    from openai.types.responses import ResponseTextConfig
except ImportError:  # For newer openai versions (>= 1.100.0)
61
    from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig
62

63

64
from openai.types.responses.response import IncompleteDetails, ToolChoice
65
66
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
67
68
69
70
71
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    TypeAdapter,
72
    ValidationError,
73
    ValidationInfo,
74
    field_serializer,
75
76
77
    field_validator,
    model_validator,
)
Zhuohan Li's avatar
Zhuohan Li committed
78

79
from vllm import envs
80
81
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id
from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam
82
from vllm.logger import init_logger
83
from vllm.logprobs import Logprob
84
from vllm.pooling_params import PoolingParams
85
86
87
88
89
90
from vllm.sampling_params import (
    BeamSearchParams,
    RequestOutputKind,
    SamplingParams,
    StructuredOutputsParams,
)
91
92
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
93

94
95
logger = init_logger(__name__)

96
_LONG_INFO = torch.iinfo(torch.long)
97

Zhuohan Li's avatar
Zhuohan Li committed
98

99
class OpenAIBaseModel(BaseModel):
    # Base class for the OpenAI-compatible schemas in this module. Unknown
    # request keys are accepted and a warning listing them is logged.

    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Cache of field names and aliases, stored separately on each class
    # that is validated (see __log_extra_fields__).
    field_names: ClassVar[set[str] | None] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Validate ``data`` and warn about keys that match no declared field.

        Args:
            data: Raw validation input. Only dict inputs are checked for
                unknown keys; any other input is validated and returned as is.
            handler: The pydantic handler that runs the inner validation.

        Returns:
            Whatever ``handler(data)`` returns (the validated model).
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Look the cache up in this class's own namespace only. Plain
        # attribute lookup (``cls.field_names``) walks the MRO and would
        # return a set cached by a *parent* model, which lacks this
        # subclass's extra fields and triggers spurious warnings.
        field_names = cls.__dict__.get("field_names")
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        if ignored := data.keys() - field_names:
            logger.warning(
                "The following fields were present in the request but ignored: %s",
                ignored,
            )
        return result
129
130


131
class ErrorInfo(OpenAIBaseModel):
    # Payload of the ``error`` object carried by ErrorResponse.
    message: str  # human-readable description
    type: str  # error category
    param: str | None = None  # related request parameter, if any
    code: int  # numeric error code
Zhuohan Li's avatar
Zhuohan Li committed
136
137


138
139
140
141
class ErrorResponse(OpenAIBaseModel):
    # Top-level error envelope: ``{"error": {...}}``.
    error: ErrorInfo


142
class ModelPermission(OpenAIBaseModel):
    # Permission entry listed in ModelCard.permission. Every field has a
    # default, so an instance can be built with no arguments.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    # Seconds since the epoch, taken when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: str | None = None
    is_blocking: bool = False
Zhuohan Li's avatar
Zhuohan Li committed
155
156


157
class ModelCard(OpenAIBaseModel):
    # A single model entry, as collected in ModelList.data.
    id: str
    object: str = "model"
    # Seconds since the epoch, taken when the card is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: str | None = None
    parent: str | None = None
    max_model_len: int | None = None
    # Fresh empty list per instance unless permissions are supplied.
    permission: list[ModelPermission] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
166
167


168
class ModelList(OpenAIBaseModel):
    # List envelope whose ``data`` holds ModelCard entries.
    object: str = "list"
    data: list[ModelCard] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
171
172


173
class PromptTokenUsageInfo(OpenAIBaseModel):
    # Breakdown of prompt-token usage; used as UsageInfo.prompt_tokens_details.
    cached_tokens: int | None = None
175
176


177
class UsageInfo(OpenAIBaseModel):
    # Token accounting; counts default to 0, the prompt breakdown to None.
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: int | None = 0
    prompt_tokens_details: PromptTokenUsageInfo | None = None
Zhuohan Li's avatar
Zhuohan Li committed
182
183


184
185
class RequestResponseMetadata(BaseModel):
    # Plain pydantic model (not OpenAIBaseModel, so no extra-field logging)
    # pairing a request id with its final usage, once known.
    request_id: str
    final_usage_info: UsageInfo | None = None
187
188


189
190
class JsonSchemaResponseFormat(OpenAIBaseModel):
    # Named JSON schema used by ResponseFormat when type == "json_schema".
    name: str
    description: str | None = None
    # The wire name is "schema", which clashes with pydantic's own
    # attribute, so the field is stored as json_schema behind an alias.
    json_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    strict: bool | None = None
196
197


198
class LegacyStructuralTag(OpenAIBaseModel):
    # One structure of the legacy structural-tag format: begin marker,
    # optional schema, end marker.
    begin: str
    # The wire name is "schema", which clashes with pydantic, so the field
    # is stored as structural_tag_schema behind an alias.
    structural_tag_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    end: str


206
class LegacyStructuralTagResponseFormat(OpenAIBaseModel):
    # Legacy structural-tag response format: explicit structures plus the
    # trigger strings.
    type: Literal["structural_tag"]
    structures: list[LegacyStructuralTag]
    triggers: list[str]


212
213
214
215
216
217
218
219
220
221
class StructuralTagResponseFormat(OpenAIBaseModel):
    # Structural-tag response format whose ``format`` payload is accepted
    # without further schema validation here (typed as Any).
    type: Literal["structural_tag"]
    format: Any


# Either flavour of structural-tag response format (member order preserved).
AnyStructuralTagResponseFormat: TypeAlias = (
    LegacyStructuralTagResponseFormat
    | StructuralTagResponseFormat
)


222
class ResponseFormat(OpenAIBaseModel):
    # Chat response_format: plain text, any JSON object, or a named schema.
    # type must be "json_schema", "json_object", or "text"
    type: Literal["text", "json_object", "json_schema"]
    # Schema details; presumably set only when type == "json_schema"
    # (not enforced here).
    json_schema: JsonSchemaResponseFormat | None = None
226
227


228
229
230
# Every response_format shape accepted by chat requests (member order
# preserved).
AnyResponseFormat: TypeAlias = (
    ResponseFormat
    | StructuralTagResponseFormat
    | LegacyStructuralTagResponseFormat
)
231
232


233
class StreamOptions(OpenAIBaseModel):
    # Streaming flags: usage is included by default, while continuous
    # usage statistics are off by default.
    include_usage: bool | None = True
    continuous_usage_stats: bool | None = False
236
237


238
239
class FunctionDefinition(OpenAIBaseModel):
    # A callable tool: its name, optional description, and optional
    # parameters dict (free-form here).
    name: str
    description: str | None = None
    parameters: dict[str, Any] | None = None
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257


class ChatCompletionToolsParam(OpenAIBaseModel):
    # Tool entry in a chat request; only function tools are modelled.
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    # Reference to a function by name, used in a named tool_choice.
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    # tool_choice variant that names one specific function.
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


258
259
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
class LogitsProcessorConstructor(BaseModel):
    # Recipe for building a logits processor: ``qualname`` names the
    # class/factory, which get_logits_processors calls with ``*args`` and
    # ``**kwargs``.
    qualname: str
    args: list[Any] | None = None
    kwargs: dict[str, Any] | None = None

    model_config = ConfigDict(extra="forbid")

267

268
# Client-supplied logits processors: bare qualified names or constructor specs.
LogitsProcessors: TypeAlias = list[str | LogitsProcessorConstructor]
269
270


271
def get_logits_processors(
    processors: LogitsProcessors | None, pattern: str | None
) -> list[Any] | None:
    """Resolve client-requested logits processors against the allow-list.

    Args:
        processors: Qualified names, or ``LogitsProcessorConstructor`` specs,
            as sent by the client.
        pattern: Regex from ``--logits-processor-pattern``. Each qualified
            name must satisfy ``re.match`` against it. A falsy pattern means
            custom logits processors are not allowed at all.

    Returns:
        ``None`` when no processors were requested. Otherwise a list with
        one entry per request item: the resolved object, or, for
        constructor specs, the result of calling it with the given
        args/kwargs.

    Raises:
        ValueError: If processors were requested but no pattern is
            configured, if a name does not match the pattern, or if a name
            cannot be resolved.
    """
    if not processors:
        return None
    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information."
        )

    resolved: list[Any] = []
    for entry in processors:
        name = entry if isinstance(entry, str) else entry.qualname
        # Check the allow-list before importing anything the client names.
        if re.match(pattern, name) is None:
            raise ValueError(
                f"Logits processor '{name}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information."
            )
        try:
            target = resolve_obj_by_qualname(name)
        except Exception as e:
            raise ValueError(
                f"Logits processor '{name}' could not be resolved: {e}"
            ) from e
        if isinstance(entry, LogitsProcessorConstructor):
            target = target(*(entry.args or []), **(entry.kwargs or {}))
        resolved.append(target)
    return resolved


305
306
307
# One element of ResponsesRequest.input: an input item param, a reasoning
# item, or a function tool call (member order preserved).
ResponseInputOutputItem: TypeAlias = (
    ResponseInputItemParam
    | ResponseReasoningItem
    | ResponseFunctionToolCall
)
308
309


310
311
312
class ResponsesRequest(OpenAIBaseModel):
    # Request body for the Responses API, plus vLLM-specific extras.
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/responses/create
    background: bool | None = False
    include: (
        list[
            Literal[
                "code_interpreter_call.outputs",
                "computer_call_output.output.image_url",
                "file_search_call.results",
                "message.input_image.image_url",
                "message.output_text.logprobs",
                "reasoning.encrypted_content",
            ],
        ]
        | None
    ) = None
    input: str | list[ResponseInputOutputItem]
    instructions: str | None = None
    max_output_tokens: int | None = None
    max_tool_calls: int | None = None
    metadata: Metadata | None = None
    model: str | None = None
    parallel_tool_calls: bool | None = True
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto"
    store: bool | None = True
    stream: bool | None = False
    temperature: float | None = None
    text: ResponseTextConfig | None = None
    tool_choice: ToolChoice = "auto"
    tools: list[Tool] = Field(default_factory=list)
    top_logprobs: int | None = 0
    top_p: float | None = None
    truncation: Literal["auto", "disabled"] | None = "disabled"
    user: str | None = None

    # --8<-- [start:responses-extra-params]
    request_id: str = Field(
        default_factory=lambda: f"resp_{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."
        ),
    )

    enable_response_messages: bool = Field(
        default=False,
        description=(
            "Dictates whether or not to return messages as part of the "
            "response object. Currently only supported for non-streaming "
            "non-background and gpt-oss only. "
        ),
    )
    # --8<-- [end:responses-extra-params]

    # Fallbacks used when neither the request nor the server defaults set
    # a value.
    _DEFAULT_SAMPLING_PARAMS = {
        "temperature": 1.0,
        "top_p": 1.0,
    }

    def to_sampling_params(
        self,
        default_max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        """Translate this request into engine ``SamplingParams``.

        Args:
            default_max_tokens: Upper bound on generated tokens. It is also
                used as-is when ``max_output_tokens`` is not set.
            default_sampling_params: Optional server-side defaults. They
                supply ``temperature``/``top_p`` when the request leaves them
                unset, and always supply ``stop_token_ids``.

        Returns:
            The sampling parameters. Output is streamed as deltas when
            ``stream`` is set and returned only at the end otherwise.

        Raises:
            NotImplementedError: If ``text.format`` is ``json_object``.
        """
        if self.max_output_tokens is None:
            max_tokens = default_max_tokens
        else:
            max_tokens = min(self.max_output_tokens, default_max_tokens)

        default_sampling_params = default_sampling_params or {}
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        stop_token_ids = default_sampling_params.get("stop_token_ids")

        # Structured output
        structured_outputs = None
        if self.text is not None and self.text.format is not None:
            response_format = self.text.format
            if (
                response_format.type == "json_schema"
                and response_format.schema_ is not None
            ):
                structured_outputs = StructuredOutputsParams(
                    json=response_format.schema_
                )
            elif response_format.type == "json_object":
                raise NotImplementedError("json_object is not supported")

        # TODO: add more parameters
        return SamplingParams.from_optional(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
            stop_token_ids=stop_token_ids,
            output_kind=(
                RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
            ),
            structured_outputs=structured_outputs,
        )

    def is_include_output_logprobs(self) -> bool:
        """Check if the request includes output logprobs."""
        if self.include is None:
            return False
        return (
            isinstance(self.include, list)
            and "message.output_text.logprobs" in self.include
        )

    # The "before" validators below see raw input. Non-dict input (e.g. a
    # model instance) is passed through untouched so pydantic can handle it,
    # instead of failing on ``data.get``.

    @model_validator(mode="before")
    @classmethod
    def validate_background(cls, data):
        """Reject ``background=True`` combined with ``store=False``."""
        if not isinstance(data, dict) or not data.get("background"):
            return data
        if not data.get("store", True):
            raise ValueError("background can only be used when `store` is true")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt(cls, data):
        """Reject requests that set ``prompt`` (templates are unsupported)."""
        if isinstance(data, dict) and data.get("prompt") is not None:
            raise ValueError("prompt template is not supported")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate ``cache_salt``: V1 engine only, non-empty string."""
        if not isinstance(data, dict):
            return data
        if data.get("cache_salt") is not None:
            if not envs.VLLM_USE_V1:
                raise ValueError(
                    "Parameter 'cache_salt' is not supported with "
                    "this instance of vLLM, which uses engine V0."
                )
            if not isinstance(data["cache_salt"], str) or not data["cache_salt"]:
                raise ValueError(
                    "Parameter 'cache_salt' must be a non-empty string if provided."
                )
        return data

    @model_validator(mode="before")
    @classmethod
    def function_call_parsing(cls, data):
        """Parse function_call dictionaries into ResponseFunctionToolCall objects.

        This ensures Pydantic can properly resolve union types in the input
        field. Function calls provided as dicts are converted to
        ResponseFunctionToolCall objects before validation, while invalid
        structures are left for Pydantic to reject with appropriate error
        messages. The caller's dict is not modified; a shallow copy with the
        processed ``input`` is returned instead.
        """
        if not isinstance(data, dict):
            return data

        input_data = data.get("input")

        # Early return for None, strings, bytes, or a dict
        # (strings are iterable but shouldn't be processed; iterating a dict
        # would silently turn it into a list of its keys)
        if input_data is None or isinstance(input_data, (str, bytes, dict)):
            return data

        # Convert iterators (like ValidatorIterator) to list
        if not isinstance(input_data, list):
            try:
                input_data = list(input_data)
            except TypeError:
                # Not iterable, leave as-is for Pydantic to handle
                return data

        processed_input = []
        for item in input_data:
            if isinstance(item, dict) and item.get("type") == "function_call":
                try:
                    processed_input.append(ResponseFunctionToolCall(**item))
                except ValidationError:
                    # Let Pydantic handle validation for malformed function calls
                    logger.debug(
                        "Failed to parse function_call to ResponseFunctionToolCall, "
                        "leaving for Pydantic validation"
                    )
                    processed_input.append(item)
            else:
                processed_input.append(item)

        return {**data, "input": processed_input}

524

525
class ChatCompletionRequest(OpenAIBaseModel):
526
527
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
528
    messages: list[ChatCompletionMessageParam]
529
530
531
532
533
534
    model: str | None = None
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: bool | None = False
    top_logprobs: int | None = 0
    max_tokens: int | None = Field(
535
        default=None,
536
537
        deprecated="max_tokens is deprecated in favor of "
        "the max_completion_tokens field",
538
    )
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
    max_completion_tokens: int | None = None
    n: int | None = 1
    presence_penalty: float | None = 0.0
    response_format: AnyResponseFormat | None = None
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    temperature: float | None = None
    top_p: float | None = None
    tools: list[ChatCompletionToolsParam] | None = None
    tool_choice: (
        Literal["none"]
        | Literal["auto"]
        | Literal["required"]
        | ChatCompletionNamedToolChoiceParam
        | None
    ) = "none"
    reasoning_effort: Literal["low", "medium", "high"] | None = None
558
    include_reasoning: bool = True
559

560
    # NOTE this will be ignored by vLLM -- the model determines the behavior
561
562
    parallel_tool_calls: bool | None = False
    user: str | None = None
563

564
    # --8<-- [start:chat-completion-sampling-params]
565
    best_of: int | None = None
566
    use_beam_search: bool = False
567
568
569
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
570
    length_penalty: float = 1.0
571
    stop_token_ids: list[int] | None = []
572
573
574
575
576
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
577
578
579
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    prompt_logprobs: int | None = None
    allowed_token_ids: list[int] | None = None
580
    bad_words: list[str] = Field(default_factory=list)
581
    # --8<-- [end:chat-completion-sampling-params]
582

583
    # --8<-- [start:chat-completion-extra-params]
584
    echo: bool = Field(
585
586
587
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
588
589
            "if they belong to the same role."
        ),
590
    )
591
    add_generation_prompt: bool = Field(
592
        default=True,
593
594
595
596
597
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
598
    )
599
600
    continue_final_message: bool = Field(
        default=False,
601
602
603
604
605
606
607
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
608
    )
609
    add_special_tokens: bool = Field(
610
611
612
613
614
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
615
            "special tokens so this should be set to false (as is the "
616
617
            "default)."
        ),
618
    )
619
    documents: list[dict[str, str]] | None = Field(
620
        default=None,
621
622
623
624
625
626
627
        description=(
            "A list of dicts representing documents that will be accessible to "
            "the model if it is performing RAG (retrieval-augmented generation)."
            " If the template does not support RAG, this argument will have no "
            "effect. We recommend that each document should be a dict containing "
            '"title" and "text" keys.'
        ),
628
    )
629
    chat_template: str | None = Field(
630
631
632
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
633
634
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
635
636
            "does not define one."
        ),
637
    )
638
    chat_template_kwargs: dict[str, Any] | None = Field(
639
        default=None,
640
641
        description=(
            "Additional keyword args to pass to the template renderer. "
642
643
            "Will be accessible by the chat template."
        ),
644
    )
645
    mm_processor_kwargs: dict[str, Any] | None = Field(
646
647
648
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
649
    structured_outputs: StructuredOutputsParams | None = Field(
650
        default=None,
651
        description="Additional kwargs for structured outputs",
652
    )
653
    guided_json: str | dict | BaseModel | None = Field(
654
655
656
657
        default=None,
        description=(
            "`guided_json` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
658
659
            "Please pass `json` to `structured_outputs` instead."
        ),
660
    )
661
    guided_regex: str | None = Field(
662
663
664
665
        default=None,
        description=(
            "`guided_regex` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
666
667
            "Please pass `regex` to `structured_outputs` instead."
        ),
668
    )
669
    guided_choice: list[str] | None = Field(
670
671
672
673
        default=None,
        description=(
            "`guided_choice` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
674
675
            "Please pass `choice` to `structured_outputs` instead."
        ),
676
    )
677
    guided_grammar: str | None = Field(
678
679
680
681
        default=None,
        description=(
            "`guided_grammar` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
682
683
            "Please pass `grammar` to `structured_outputs` instead."
        ),
684
    )
685
    structural_tag: str | None = Field(
686
687
688
689
        default=None,
        description=(
            "`structural_tag` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
690
691
            "Please pass `structural_tag` to `structured_outputs` instead."
        ),
692
    )
693
    guided_decoding_backend: str | None = Field(
694
695
696
697
        default=None,
        description=(
            "`guided_decoding_backend` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
698
699
            "Please remove it from your request."
        ),
700
    )
701
    guided_whitespace_pattern: str | None = Field(
702
703
704
705
706
707
708
        default=None,
        description=(
            "`guided_whitespace_pattern` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `whitespace_pattern` to `structured_outputs` instead."
        ),
    )
709
710
711
712
713
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
714
715
            "if the served model does not use priority scheduling."
        ),
716
    )
717
718
719
720
721
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
722
723
            "through out the inference process and return in response."
        ),
724
    )
725
    logits_processors: LogitsProcessors | None = Field(
726
727
728
729
730
731
732
733
734
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
735
736
737
            "{'param': 'value'}}."
        ),
    )
738
    # Only meaningful together with `logprobs`: tokens are rendered as
    # "token_id:{token_id}" strings instead of decoded text.
    return_tokens_as_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."
        ),
    )
    # Opt-in: include token IDs next to the generated text.
    return_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."
        ),
    )
    # Per-request salt for the prefix cache; validated (non-empty str, V1 only)
    # by `check_cache_salt_support`.
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."
        ),
    )
    # Handed to the engine inside SamplingParams.extra_args by
    # `to_sampling_params`.
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )

    # Free-form extension parameters; also ends up in SamplingParams.extra_args.
    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )

    # --8<-- [end:chat-completion-extra-params]

    # Last-resort fallbacks for sampling knobs that neither the request nor the
    # server-provided defaults specify (chat completion requests).
    _DEFAULT_SAMPLING_PARAMS: dict = dict(
        repetition_penalty=1.0,
        temperature=1.0,
        top_p=1.0,
        top_k=0,
        min_p=0.0,
    )

    def to_beam_search_params(
        self, max_tokens: int, default_sampling_params: dict
    ) -> BeamSearchParams:
        """Build ``BeamSearchParams`` for this chat request.

        The beam width is ``n`` (1 when unset). An unset ``temperature`` is
        taken from ``default_sampling_params``, falling back to
        ``_DEFAULT_SAMPLING_PARAMS``.
        """
        beam_width = 1 if self.n is None else self.n

        temperature = self.temperature
        if temperature is None:
            fallback = self._DEFAULT_SAMPLING_PARAMS["temperature"]
            temperature = default_sampling_params.get("temperature", fallback)

        return BeamSearchParams(
            beam_width=beam_width,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )

    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: str | None,
        default_sampling_params: dict,
    ) -> SamplingParams:
        """Translate this chat request into engine ``SamplingParams``.

        Sampling knobs left as ``None`` on the request are filled from
        ``default_sampling_params`` and then from ``_DEFAULT_SAMPLING_PARAMS``.
        Deprecated ``guided_*`` fields, ``structural_tag`` and
        ``response_format`` are folded into ``self.structured_outputs``; note
        that this mutates the request object.

        Args:
            max_tokens: Generation budget passed through to ``SamplingParams``.
            logits_processor_pattern: Pattern forwarded to
                ``get_logits_processors`` along with ``self.logits_processors``.
            default_sampling_params: Server-side defaults for unset knobs.

        Returns:
            The populated ``SamplingParams``.
        """
        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        # When echoing, default the prompt logprobs to `top_logprobs`.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        # Forward deprecated guided_* parameters to structured_outputs
        if self.structured_outputs is None:
            kwargs = dict[str, Any](
                json=self.guided_json,
                regex=self.guided_regex,
                choice=self.guided_choice,
                grammar=self.guided_grammar,
                whitespace_pattern=self.guided_whitespace_pattern,
                structural_tag=self.structural_tag,
            )
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            if len(kwargs) > 0:
                self.structured_outputs = StructuredOutputsParams(**kwargs)

        response_format = self.response_format
        if response_format is not None:
            # If structured outputs wasn't already enabled,
            # we must enable it for these features to work
            if self.structured_outputs is None:
                self.structured_outputs = StructuredOutputsParams()

            # Set structured output params for response format; other types
            # (e.g. "text") set no constraint field.
            if response_format.type == "json_object":
                self.structured_outputs.json_object = True
            elif response_format.type == "json_schema":
                json_schema = response_format.json_schema
                assert json_schema is not None
                self.structured_outputs.json = json_schema.json_schema
            elif response_format.type == "structural_tag":
                structural_tag = response_format
                assert structural_tag is not None and isinstance(
                    structural_tag,
                    (
                        LegacyStructuralTagResponseFormat,
                        StructuralTagResponseFormat,
                    ),
                )
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

        # Copy `vllm_xargs` so that injecting `kv_transfer_params` below does
        # not mutate the dict stored on this request.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            # Chat API: `logprobs` is a flag, `top_logprobs` is the count.
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            bad_words=self.bad_words,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject ``stream_options`` unless the request also sets ``stream``."""
        stream_options = data.get("stream_options")
        if stream_options and not data.get("stream"):
            raise ValueError("Stream options can only be defined when `stream=True`.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate the logprob options of a raw chat request.

        ``prompt_logprobs`` must be >= 0 or -1 (-1 only on engine V1) and
        cannot be requested with ``stream=True``. ``top_logprobs`` must be
        >= 0 or -1, and a non-zero value requires ``logprobs`` to be truthy.

        Raises:
            ValueError: if any of the rules above is violated.
        """
        prompt_logprobs = data.get("prompt_logprobs")
        if prompt_logprobs is not None:
            streaming = data.get("stream")
            if streaming and (prompt_logprobs > 0 or prompt_logprobs == -1):
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`."
                )
            if prompt_logprobs < 0 and prompt_logprobs != -1:
                raise ValueError("`prompt_logprobs` must be a positive value or -1.")
            if prompt_logprobs == -1 and not envs.VLLM_USE_V1:
                raise ValueError(
                    "`prompt_logprobs=-1` is only supported with vLLM engine V1."
                )

        top_logprobs = data.get("top_logprobs")
        if top_logprobs is None:
            return data

        if top_logprobs < 0 and top_logprobs != -1:
            raise ValueError("`top_logprobs` must be a positive value or -1.")
        # -1 ("all") or any positive count needs `logprobs` switched on.
        needs_logprobs = top_logprobs == -1 or top_logprobs > 0
        if needs_logprobs and not data.get("logprobs"):
            raise ValueError(
                "when using `top_logprobs`, `logprobs` must be set to true."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Validate the ``structured_outputs`` constraints of a raw request.

        At most one of ``json``, ``regex`` or ``choice`` may be set, and any
        such constraint cannot be combined with a named-tool ``tool_choice``
        (i.e. anything other than null, ``"none"``, ``"auto"`` or
        ``"required"``).

        Raises:
            ValueError: if either rule is violated.
        """
        if isinstance(data, ValueError):
            raise data

        if data.get("structured_outputs", None) is None:
            return data

        structured_outputs_kwargs = data["structured_outputs"]
        count = sum(
            structured_outputs_kwargs.get(k) is not None
            for k in ("json", "regex", "choice")
        )
        # you can only use one kind of constraints for structured outputs
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )
        # You can only either use structured outputs or tools, not both.
        # Any single constraint (count > 0) conflicts with a named tool; the
        # `count > 1` case has already been rejected above. An explicit null
        # tool_choice behaves like "none".
        tool_choice = data.get("tool_choice")
        if (
            count > 0
            and tool_choice is not None
            and tool_choice not in ("none", "auto", "required")
        ):
            raise ValueError(
                "You can only either use constraints for structured outputs "
                "or tools, not both."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        """Normalize and validate ``tool_choice`` against ``tools``.

        - When ``tools`` are given without ``tool_choice``, default it to
          ``"auto"`` (mutates ``data``).
        - ``"none"`` skips all further checks.
        - Any other non-null ``tool_choice`` requires ``tools`` and must be
          ``"auto"``, ``"required"`` or a named-tool dict whose
          ``function.name`` matches one of the provided tools.
        - ``"required"`` with an empty ``tools`` list is rewritten to
          ``"none"`` and ``tools`` is removed, matching OpenAI's behavior.

        Malformed ``tools`` entries are treated as non-matching rather than
        raising ``KeyError``/``TypeError``, which pydantic would not turn into
        a validation error.

        Raises:
            ValueError: on any invalid ``tool_choice``/``tools`` combination.
        """

        def _function_name(tool):
            # Raw JSON payloads carry dicts; programmatic callers may pass
            # objects exposing `.function.name`. Anything else has no name.
            if isinstance(tool, dict):
                function_spec = tool.get("function")
            else:
                function_spec = getattr(tool, "function", None)
            if isinstance(function_spec, dict):
                return function_spec.get("name")
            return getattr(function_spec, "name", None)

        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        # if "tool_choice" is "none" -- no validation is needed for tools
        if "tool_choice" in data and data["tool_choice"] == "none":
            return data

        # if "tool_choice" is specified -- validation
        if "tool_choice" in data and data["tool_choice"] is not None:
            # ensure that if "tool choice" is specified, tools are present
            if "tools" not in data or data["tools"] is None:
                raise ValueError("When using `tool_choice`, `tools` must be set.")

            # make sure that tool choice is either a named tool
            # OR that it's set to "auto" or "required"
            if data["tool_choice"] not in ["auto", "required"] and not isinstance(
                data["tool_choice"], dict
            ):
                raise ValueError(
                    f"Invalid value for `tool_choice`: {data['tool_choice']}! "
                    'Only named tools, "none", "auto" or "required" '
                    "are supported."
                )

            # if tool_choice is "required" but the "tools" list is empty,
            # override the data to behave like "none" to align with
            # OpenAI’s behavior.
            if (
                data["tool_choice"] == "required"
                and isinstance(data["tools"], list)
                and len(data["tools"]) == 0
            ):
                data["tool_choice"] = "none"
                del data["tools"]
                return data

            # ensure that if "tool_choice" is specified as an object,
            # it matches a valid tool
            correct_usage_message = (
                'Correct usage: `{"type": "function",'
                ' "function": {"name": "my_function"}}`'
            )
            if isinstance(data["tool_choice"], dict):
                function = data["tool_choice"].get("function")
                if not isinstance(function, dict):
                    raise ValueError(
                        f"Invalid value for `function`: `{function}` in "
                        f"`tool_choice`! {correct_usage_message}"
                    )
                if "name" not in function:
                    raise ValueError(
                        f"Expected field `name` in `function` in "
                        f"`tool_choice`! {correct_usage_message}"
                    )
                function_name = function["name"]
                if not isinstance(function_name, str) or len(function_name) == 0:
                    raise ValueError(
                        f"Invalid `name` in `function`: `{function_name}`"
                        f" in `tool_choice`! {correct_usage_message}"
                    )
                tools = data["tools"]
                # A non-sequence `tools` value cannot contain the named tool;
                # field validation will report its type separately.
                if not isinstance(tools, (list, tuple)):
                    tools = ()
                if not any(_function_name(tool) == function_name for tool in tools):
                    raise ValueError(
                        "The tool specified in `tool_choice` does not match any"
                        " of the specified `tools`"
                    )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Forbid enabling ``continue_final_message`` together with
        ``add_generation_prompt``; the two are mutually exclusive.
        """
        if not data.get("continue_final_message"):
            return data
        if data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate ``cache_salt`` when present.

        The salt requires engine V1 (``envs.VLLM_USE_V1``) and must be a
        non-empty string.

        Raises:
            ValueError: on V0, or when the salt is not a non-empty string.
        """
        salt = data.get("cache_salt")
        if salt is None:
            return data

        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Parameter 'cache_salt' is not supported with "
                "this instance of vLLM, which uses engine V0."
            )
        is_valid_salt = isinstance(salt, str) and bool(salt)
        if not is_valid_salt:
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )
        return data


class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str | None = None
    prompt: list[int] | list[list[int]] | str | list[str] | None = None
    best_of: int | None = None
    echo: bool | None = False
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: int | None = None
    max_tokens: int | None = 16
    n: int = 1
    presence_penalty: float | None = 0.0
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    suffix: str | None = None
    temperature: float | None = None
    top_p: float | None = None
    user: str | None = None

    # --8<-- [start:completion-sampling-params]
    use_beam_search: bool = False
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
    length_penalty: float = 1.0
    stop_token_ids: list[int] | None = []
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    allowed_token_ids: list[int] | None = None
    prompt_logprobs: int | None = None
    # --8<-- [end:completion-sampling-params]

    # --8<-- [start:completion-extra-params]
    prompt_embeds: bytes | list[bytes] | None = None
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    response_format: AnyResponseFormat | None = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
            ", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
        ),
    )
    structured_outputs: StructuredOutputsParams | None = Field(
        default=None,
        description="Additional kwargs for structured outputs",
    )
    guided_json: str | dict | BaseModel | None = Field(
        default=None,
        description=(
            "`guided_json` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `json` to `structured_outputs` instead."
        ),
    )
    guided_regex: str | None = Field(
        default=None,
        description=(
            "`guided_regex` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `regex` to `structured_outputs` instead."
        ),
    )
    guided_choice: list[str] | None = Field(
        default=None,
        description=(
            "`guided_choice` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `choice` to `structured_outputs` instead."
        ),
    )
    guided_grammar: str | None = Field(
        default=None,
        description=(
            "`guided_grammar` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `grammar` to `structured_outputs` instead."
        ),
    )
    structural_tag: str | None = Field(
        default=None,
        description=("If specified, the output will follow the structural tag schema."),
    )
    guided_decoding_backend: str | None = Field(
        default=None,
        description=(
            "`guided_decoding_backend` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please remove it from your request."
        ),
    )
    guided_whitespace_pattern: str | None = Field(
        default=None,
        description=(
            "`guided_whitespace_pattern` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `whitespace_pattern` to `structured_outputs` instead."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    logits_processors: LogitsProcessors | None = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."
        ),
    )

    return_tokens_as_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."
        ),
    )
    return_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."
        ),
    )

    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."
        ),
    )

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )

    # --8<-- [end:completion-extra-params]

    # Default sampling parameters for completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self,
        max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> BeamSearchParams:
        """Build ``BeamSearchParams`` for this completion request.

        The beam width is ``n`` (1 when unset). An unset ``temperature`` is
        taken from ``default_sampling_params``, falling back to
        ``_DEFAULT_SAMPLING_PARAMS`` (the same fallback the chat request and
        ``to_sampling_params`` use).
        """
        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )

    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: str | None,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        """Translate this completion request into engine ``SamplingParams``.

        Sampling knobs left as ``None`` are filled from
        ``default_sampling_params`` and then ``_DEFAULT_SAMPLING_PARAMS``.
        ``response_format`` is mapped onto the deprecated ``guided_*`` /
        ``structural_tag`` fields, which are then forwarded into
        ``structured_outputs`` when the caller did not set it explicitly.
        This mutates the request object.

        Args:
            max_tokens: Generation budget; replaced by 1 when echoing with
                ``max_tokens == 0`` on the request.
            logits_processor_pattern: Pattern forwarded to
                ``get_logits_processors`` along with ``self.logits_processors``.
            default_sampling_params: Server-side defaults for unset knobs.

        Returns:
            The populated ``SamplingParams``.
        """
        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        # When echoing, default the prompt logprobs to `logprobs`.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        # Echo with max_tokens == 0: still run the engine for one token (see
        # `max_tokens=` below) so the echoed prompt can be produced.
        echo_without_generation = self.echo and self.max_tokens == 0

        guided_json_object = None
        if self.response_format is not None:
            if self.response_format.type == "json_object":
                guided_json_object = True
            elif self.response_format.type == "json_schema":
                json_schema = self.response_format.json_schema
                assert json_schema is not None
                self.guided_json = json_schema.json_schema
            elif self.response_format.type == "structural_tag":
                structural_tag = self.response_format
                # Accept both structural-tag formats, as the chat path does.
                assert structural_tag is not None and isinstance(
                    structural_tag,
                    (
                        LegacyStructuralTagResponseFormat,
                        StructuralTagResponseFormat,
                    ),
                )
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structural_tag = json.dumps(s_tag_obj)

        # Forward deprecated guided_* parameters to structured_outputs
        if self.structured_outputs is None:
            kwargs = dict[str, Any](
                json=self.guided_json,
                json_object=guided_json_object,
                regex=self.guided_regex,
                choice=self.guided_choice,
                grammar=self.guided_grammar,
                whitespace_pattern=self.guided_whitespace_pattern,
                # Without this, a structural_tag (from the field or from
                # response_format above) would be silently ignored.
                structural_tag=self.structural_tag,
            )
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            if len(kwargs) > 0:
                self.structured_outputs = StructuredOutputsParams(**kwargs)

        # Copy `vllm_xargs` so that injecting `kv_transfer_params` below does
        # not mutate the dict stored on this request.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )

    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Allow at most one of ``json``/``regex``/``choice`` in
        ``structured_outputs``.

        Raises:
            ValueError: if more than one of them is set.
        """
        if data.get("structured_outputs", None) is None:
            return data

        structured_outputs_kwargs = data["structured_outputs"]
        count = sum(
            structured_outputs_kwargs.get(k) is not None
            for k in ("json", "regex", "choice")
        )
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate ``prompt_logprobs`` and ``logprobs``.

        ``prompt_logprobs`` must be >= 0 or -1 (-1 only on engine V1) and
        cannot be requested with ``stream=True``; ``logprobs`` must be >= 0.

        Raises:
            ValueError: if any of the rules above is violated.
        """
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`."
                )

            if prompt_logprobs < 0 and prompt_logprobs != -1:
                raise ValueError("`prompt_logprobs` must be a positive value or -1.")
            if prompt_logprobs == -1 and not envs.VLLM_USE_V1:
                raise ValueError(
                    "`prompt_logprobs=-1` is only supported with vLLM engine V1."
                )
        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise ValueError("`logprobs` must be a positive value.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject ``stream_options`` unless the request also sets ``stream``."""
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt_and_prompt_embeds(cls, data):
        """Require a non-empty ``prompt`` or a non-empty ``prompt_embeds``.

        Only ``None``/``""`` prompts and ``None``/``[]`` embeds count as empty.

        Raises:
            ValueError: if both are empty.
        """
        prompt = data.get("prompt")
        prompt_embeds = data.get("prompt_embeds")

        prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "")
        embeds_is_empty = prompt_embeds is None or (
            isinstance(prompt_embeds, list) and len(prompt_embeds) == 0
        )

        if prompt_is_empty and embeds_is_empty:
            raise ValueError(
                "Either prompt or prompt_embeds must be provided and non-empty."
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate ``cache_salt``: engine V1 only, and a non-empty string.

        Raises:
            ValueError: on V0, or when the salt is not a non-empty string.
        """
        if data.get("cache_salt") is not None:
            if not envs.VLLM_USE_V1:
                raise ValueError(
                    "Parameter 'cache_salt' is not supported with "
                    "this instance of vLLM, which uses engine V0."
                )
            if not isinstance(data["cache_salt"], str) or not data["cache_salt"]:
                raise ValueError(
                    "Parameter 'cache_salt' must be a non-empty string if provided."
                )
        return data


# Request body for the embeddings endpoint when the input is raw text/tokens.
class EmbeddingCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str | None = None
    input: list[int] | list[list[int]] | str | list[str]
    encoding_format: EncodingFormat = "float"
    dimensions: int | None = None
    user: str | None = None
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:embedding-extra-params]
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    normalize: bool | None = Field(
        default=None,
        description="Whether to normalize the embeddings outputs. Default is True.",
    )
    embed_dtype: EmbedDType = Field(
        default="float32",
        description=(
            "What dtype to use for encoding. Default to using float32 for base64 "
            "encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    endianness: Endianness = Field(
        default="native",
        description=(
            "What endianness to use for encoding. Default to using native for "
            "base64 encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    # --8<-- [end:embedding-extra-params]

    def to_pooling_params(self):
        """Build the PoolingParams for this request.

        Only ``truncate_prompt_tokens``, ``dimensions`` and ``normalize`` are
        forwarded.
        """
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            dimensions=self.dimensions,
            normalize=self.normalize,
        )
1557
1558


1559
# Request body for the embeddings endpoint when the input is a chat
# conversation rendered through a chat template.
class EmbeddingChatRequest(OpenAIBaseModel):
    model: str | None = None
    messages: list[ChatCompletionMessageParam]

    encoding_format: EncodingFormat = "float"
    dimensions: int | None = None
    user: str | None = None
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:chat-embedding-extra-params]
    add_generation_prompt: bool = Field(
        default=False,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )

    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    normalize: bool | None = Field(
        default=None,
        description="Whether to normalize the embeddings outputs. Default is True.",
    )
    embed_dtype: EmbedDType = Field(
        default="float32",
        description=(
            "What dtype to use for encoding. Default to using float32 for base64 "
            "encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    endianness: Endianness = Field(
        default="native",
        description=(
            "What endianness to use for encoding. Default to using native for "
            "base64 encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    # --8<-- [end:chat-embedding-extra-params]

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests that set both ``continue_final_message`` and
        ``add_generation_prompt``.

        ``continue_final_message`` is not declared on this model; it is read
        from the raw input only.
        """
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

    def to_pooling_params(self):
        """Build the PoolingParams for this request.

        Only ``truncate_prompt_tokens``, ``dimensions`` and ``normalize`` are
        forwarded.
        """
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            dimensions=self.dimensions,
            normalize=self.normalize,
        )
1662
1663


1664
# Either form of embeddings request (raw input or chat messages).
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest

# The pooling endpoint reuses the embeddings request models unchanged.
PoolingCompletionRequest = EmbeddingCompletionRequest
PoolingChatRequest = EmbeddingChatRequest

# Payload type parameter for the generic IOProcessor request/response models.
T = TypeVar("T")


# Request whose ``data`` payload has a plugin-defined type ``T``.
class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
    model: str | None = None

    priority: int = Field(default=0)
    """
    The priority of the request (lower means earlier handling;
    default: 0). Any priority other than 0 will raise an error
    if the served model does not use priority scheduling.
    """
    data: T

    encoding_format: EncodingFormat = "float"
    embed_dtype: EmbedDType = Field(
        default="float32",
        description=(
            "What dtype to use for encoding. Default to using float32 for base64 "
            "encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    endianness: Endianness = Field(
        default="native",
        description=(
            "What endianness to use for encoding. Default to using native for "
            "base64 encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )

    def to_pooling_params(self):
        """Return default PoolingParams; no request field is forwarded."""
        return PoolingParams()
1703
1704
1705


# Response whose ``data`` payload has a plugin-defined type ``T``.
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
    request_id: str | None = None
    """
    The request_id associated with this response
    """
    created_at: int = Field(default_factory=lambda: int(time.time()))

    data: T
    """
    When using IOProcessor plugins, the actual output is generated
    by the plugin itself. Hence, we use a generic type for the response data
    """


1719
1720
1721
# Any request body accepted by the pooling endpoint: the two embeddings-style
# request shapes, or a plugin-defined IOProcessor payload.
PoolingRequest: TypeAlias = (
    PoolingCompletionRequest
    | PoolingChatRequest
    | IOProcessorRequest
)
1722

1723

1724
# Request body for the score endpoint, which compares ``text_1`` against
# ``text_2``.
class ScoreRequest(OpenAIBaseModel):
    model: str | None = None
    text_1: list[str] | str | ScoreMultiModalParam
    text_2: list[str] | str | ScoreMultiModalParam
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:score-extra-params]

    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    activation: bool | None = None

    # --8<-- [end:score-extra-params]

    def to_pooling_params(self):
        """Build the PoolingParams for this request.

        Only ``truncate_prompt_tokens`` and ``activation`` are forwarded.
        """
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation,
        )
1755
1756


1757
# Request body for the rerank endpoint: a query plus candidate documents.
class RerankRequest(OpenAIBaseModel):
    model: str | None = None
    query: str | ScoreMultiModalParam
    documents: list[str] | ScoreMultiModalParam
    top_n: int = Field(default_factory=lambda: 0)
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:rerank-extra-params]

    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    activation: bool | None = None

    # --8<-- [end:rerank-extra-params]

    def to_pooling_params(self):
        """Build the PoolingParams for this request.

        Only ``truncate_prompt_tokens`` and ``activation`` are forwarded.
        """
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation,
        )
1789
1790
1791


# A reranked document, carried either as text or as a multimodal content part.
class RerankDocument(BaseModel):
    text: str | None = None
    multi_modal: ScoreContentPartParam | None = None
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809


class RerankResult(BaseModel):
    # One entry of a rerank response.
    # index: presumably the position of the document in the request - confirm.
    index: int
    document: RerankDocument
    relevance_score: float


class RerankUsage(BaseModel):
    # Token accounting for a rerank response; only the total is reported.
    total_tokens: int


# Top-level body returned by the rerank endpoint.
class RerankResponse(OpenAIBaseModel):
    id: str
    model: str
    usage: RerankUsage
    results: list[RerankResult]
1811
1812


1813
# Log-probability data for a completion choice; all four lists default to empty.
class CompletionLogProbs(OpenAIBaseModel):
    text_offset: list[int] = Field(default_factory=list)
    token_logprobs: list[float | None] = Field(default_factory=list)
    tokens: list[str] = Field(default_factory=list)
    top_logprobs: list[dict[str, float] | None] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
1818
1819


1820
# One choice in a non-streaming completion response.
class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: CompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    token_ids: list[int] | None = None  # For response
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    prompt_token_ids: list[int] | None = None  # For prompt
Zhuohan Li's avatar
Zhuohan Li committed
1836
1837


1838
# Non-streaming response body of the completions endpoint.
class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: Literal["text_completion"] = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    system_fingerprint: str | None = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None, description="KVTransfer parameters."
    )
Zhuohan Li's avatar
Zhuohan Li committed
1852
1853


1854
# One choice within a streamed completion chunk.
class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: CompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    # not part of the OpenAI spec but for tracing the tokens
    # prompt tokens is put into choice to align with CompletionResponseChoice
    prompt_token_ids: list[int] | None = None
    token_ids: list[int] | None = None
Zhuohan Li's avatar
Zhuohan Li committed
1871
1872


1873
# A single streamed chunk of the completions endpoint.
class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
1880
1881


1882
# One embedding in an embeddings response.
class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    # A float list, or a str (presumably the encoded form for non-"float"
    # encoding formats - confirm against the serving code).
    embedding: list[float] | str
1886
1887


1888
# Top-level body returned by the embeddings endpoint.
class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[EmbeddingResponseData]
    usage: UsageInfo


1897
1898
1899
1900
1901
1902
class EmbeddingBytesResponse(OpenAIBaseModel):
    # Raw-bytes form of an embeddings response: binary body chunks plus a
    # str metadata field; media_type defaults to application/octet-stream.
    body: list[bytes]
    metadata: str
    media_type: str = "application/octet-stream"


1903
1904
1905
# One pooled output; ``data`` may be a 2-D or 1-D float list, or a str.
class PoolingResponseData(OpenAIBaseModel):
    index: int
    object: str = "pooling"
    data: list[list[float]] | list[float] | str
1907
1908
1909
1910
1911
1912
1913


# Top-level body returned by the pooling endpoint.
class PoolingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[PoolingResponseData]
    usage: UsageInfo


1918
1919
1920
1921
1922
1923
class PoolingBytesResponse(OpenAIBaseModel):
    # Raw-bytes form of a pooling response: binary body chunks plus a str
    # metadata field; media_type defaults to application/octet-stream.
    body: list[bytes]
    metadata: str
    media_type: str = "application/octet-stream"


1924
1925
1926
# One score entry in a score response.
class ScoreResponseData(OpenAIBaseModel):
    index: int
    object: str = "score"
    score: float
1928
1929
1930
1931
1932
1933
1934


# Top-level body returned by the score endpoint.
class ScoreResponse(OpenAIBaseModel):
    # NOTE(review): reuses the "embd-" id prefix of EmbeddingResponse;
    # confirm this is intended before changing (ids are client-visible).
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ScoreResponseData]
    usage: UsageInfo


1939
# Request body for the classification endpoint.
class ClassificationRequest(OpenAIBaseModel):
    model: str | None = None
    input: list[str] | str
    # Same ge=-1 constraint as the other pooling requests in this module.
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    user: str | None = None

    # --8<-- [start:classification-extra-params]
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    activation: bool | None = None

    # --8<-- [end:classification-extra-params]

    def to_pooling_params(self):
        """Build the PoolingParams for this request.

        Only ``truncate_prompt_tokens`` and ``activation`` are forwarded.
        """
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation,
        )
1964
1965
1966
1967


# One classification result. ``label`` is required but may be None.
class ClassificationData(OpenAIBaseModel):
    index: int
    label: str | None
    probs: list[float]
    num_classes: int


class ClassificationResponse(OpenAIBaseModel):
    # Top-level body returned by the classification endpoint; the id is
    # generated with a "classify-" prefix.
    id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ClassificationData]
    usage: UsageInfo


1982
1983
1984
1985
1986
1987
class FunctionCall(OpenAIBaseModel):
    # A function invocation: its name and its arguments as a single string
    # (presumably JSON-encoded, as in the OpenAI API - confirm).
    name: str
    arguments: str


# A complete tool call; the id is generated by make_tool_call_id if omitted.
class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=make_tool_call_id)
    type: Literal["function"] = "function"
    function: FunctionCall


1993
# Partial function-call data carried by a streamed tool-call delta.
class DeltaFunctionCall(BaseModel):
    name: str | None = None
    arguments: str | None = None
1996
1997
1998
1999


# A streamed tool-call delta: every field is optional except `index`.
class DeltaToolCall(OpenAIBaseModel):
    id: str | None = None
    type: Literal["function"] | None = None
    index: int
    function: DeltaFunctionCall | None = None
2004
2005
2006
2007
2008
2009
2010


class ExtractedToolCallInformation(BaseModel):
    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: list[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: str | None = None
2016
2017


2018
# The assistant message returned in a non-streaming chat completion choice.
class ChatMessage(OpenAIBaseModel):
    role: str
    content: str | None = None
    refusal: str | None = None
    annotations: OpenAIAnnotation | None = None
    audio: OpenAIChatCompletionAudio | None = None
    function_call: FunctionCall | None = None
    tool_calls: list[ToolCall] = Field(default_factory=list)

    # vLLM-specific fields that are not in OpenAI spec
    reasoning_content: str | None = None
2029

2030

2031
2032
2033
# Log-probability entry for one token; logprob defaults to -9999.0.
class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    logprob: float = -9999.0
    bytes: list[int] | None = None
2035
2036
2037


# A sampled token's log-probability entry plus its top alternatives.
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    # Workaround: redefine the field-names cache so that it's not
    # shared with the super class.
    field_names: ClassVar[set[str] | None] = None
    top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)
2042
2043
2044


# Log-probability payload of a chat completion choice.
class ChatCompletionLogProbs(OpenAIBaseModel):
    content: list[ChatCompletionLogProbsContent] | None = None
2046
2047


2048
# One choice in a non-streaming chat completion response.
class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: ChatCompletionLogProbs | None = None
    # per OpenAI spec this is the default
    finish_reason: str | None = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: int | str | None = None
    # not part of the OpenAI spec but is useful for tracing the tokens
    # in agent scenarios
    token_ids: list[int] | None = None
2059
2060


2061
# Non-streaming response body of the chat completions endpoint.
class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    system_fingerprint: str | None = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    prompt_token_ids: list[int] | None = None
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None, description="KVTransfer parameters."
    )
    )
2077
2078


2079
# Incremental message content carried by a streamed chunk.
class DeltaMessage(OpenAIBaseModel):
    role: str | None = None
    content: str | None = None
    reasoning_content: str | None = None
    tool_calls: list[DeltaToolCall] = Field(default_factory=list)
2084
2085


2086
# One choice within a streamed chat completion chunk.
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: ChatCompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = None
    # not part of the OpenAI spec but for tracing the tokens
    token_ids: list[int] | None = None
2094
2095


2096
# A single streamed chunk of the chat completions endpoint.
class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
    # not part of the OpenAI spec but for tracing the tokens
    prompt_token_ids: list[int] | None = None
2105
2106


2107
2108
# One choice within a streamed transcription chunk.
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None
2111
2112
2113
2114
2115
2116
2117
2118


# A single streamed chunk of the transcription endpoint.
class TranscriptionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
    object: Literal["transcription.chunk"] = "transcription.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranscriptionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
2120
2121


2122
2123
# Input-side token accounting for a Responses API usage report; the
# per-turn lists default to empty.
class InputTokensDetails(OpenAIBaseModel):
    cached_tokens: int
    input_tokens_per_turn: list[int] = Field(default_factory=list)
    cached_tokens_per_turn: list[int] = Field(default_factory=list)
2126
2127
2128


# Output-side token accounting for a Responses API usage report; counters
# default to 0 and per-turn lists default to empty.
class OutputTokensDetails(OpenAIBaseModel):
    reasoning_tokens: int = 0
    tool_output_tokens: int = 0
    output_tokens_per_turn: list[int] = Field(default_factory=list)
    tool_output_tokens_per_turn: list[int] = Field(default_factory=list)
2133
2134
2135
2136
2137
2138
2139
2140


class ResponseUsage(OpenAIBaseModel):
    # Usage block of a ResponsesResponse: input/output token totals with
    # their detail breakdowns, plus the overall total.
    input_tokens: int
    input_tokens_details: InputTokensDetails
    output_tokens: int
    output_tokens_details: OutputTokensDetails
    total_tokens: int
2141
2142


2143
2144
2145
2146
2147
2148
def serialize_message(msg):
    """Serialize a single message for JSON output.

    - dicts are returned unchanged;
    - objects exposing ``to_dict`` are converted with it;
    - anything else is assumed to be a pydantic model and dumped with
      ``model_dump_json``.

    NOTE(review): the pydantic fallback returns a JSON *string*, while the
    other branches return a dict; confirm clients expect this mixed shape
    before switching to ``model_dump(mode="json")``.
    """
    if isinstance(msg, dict):
        return msg
    elif hasattr(msg, "to_dict"):
        return msg.to_dict()
    else:
        # fallback to pydantic dump
        return msg.model_dump_json()


def serialize_messages(msgs):
    """Serialize every message in ``msgs`` with ``serialize_message``.

    Returns None (rather than an empty list) when ``msgs`` is None or empty.
    """
    if not msgs:
        return None
    return list(map(serialize_message, msgs))


2163
2164
2165
2166
# Body of a Responses API response object.
class ResponsesResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
    created_at: int = Field(default_factory=lambda: int(time.time()))
    # NOTE: the OpenAI `error` field (ResponseError) is not declared here.
    incomplete_details: IncompleteDetails | None = None
    instructions: str | None = None
    metadata: Metadata | None = None
    model: str
    object: Literal["response"] = "response"
    output: list[ResponseOutputItem]
    parallel_tool_calls: bool
    temperature: float
    tool_choice: ToolChoice
    tools: list[Tool]
    top_p: float
    background: bool
    max_output_tokens: int
    max_tool_calls: int | None = None
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"]
    status: ResponseStatus
    text: ResponseTextConfig | None = None
    top_logprobs: int | None = None
    truncation: Literal["auto", "disabled"]
    usage: ResponseUsage | None = None
    user: str | None = None

    # --8<-- [start:responses-extra-params]
    # These are populated when enable_response_messages is set to True
    # NOTE: custom serialization is needed
    # see serialize_input_messages and serialize_output_messages
    input_messages: list[ChatCompletionMessageParam] | None = None
    output_messages: list[ChatCompletionMessageParam] | None = None
    # --8<-- [end:responses-extra-params]

    # NOTE: openAI harmony doesn't serialize TextContent properly,
    # TODO: this fixes for TextContent, but need to verify for tools etc
    # https://github.com/openai/harmony/issues/78
    @field_serializer("output_messages", when_used="json")
    def serialize_output_messages(self, msgs, _info):
        """JSON-mode serializer for ``output_messages`` (see serialize_messages)."""
        return serialize_messages(msgs)

    # NOTE: openAI harmony doesn't serialize TextContent properly, this fixes it
    # https://github.com/openai/harmony/issues/78
    @field_serializer("input_messages", when_used="json")
    def serialize_input_messages(self, msgs, _info):
        """JSON-mode serializer for ``input_messages`` (see serialize_messages)."""
        return serialize_messages(msgs)

    @classmethod
    def from_request(
        cls,
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        model_name: str,
        created_time: int,
        output: list[ResponseOutputItem],
        status: ResponseStatus,
        usage: ResponseUsage | None = None,
        input_messages: list[ChatCompletionMessageParam] | None = None,
        output_messages: list[ChatCompletionMessageParam] | None = None,
    ) -> "ResponsesResponse":
        """Assemble a response from the originating request and its results.

        Request-level settings are copied from ``request``; temperature,
        top_p, max_output_tokens and top_logprobs come from
        ``sampling_params``. When ``status == "incomplete"`` the
        ``incomplete_details`` reason is always reported as
        ``"max_output_tokens"``.
        """
        incomplete_details: IncompleteDetails | None = None
        if status == "incomplete":
            incomplete_details = IncompleteDetails(reason="max_output_tokens")
        # TODO: implement the other reason for incomplete_details,
        # which is content_filter
        # incomplete_details = IncompleteDetails(reason='content_filter')
        return cls(
            id=request.request_id,
            created_at=created_time,
            incomplete_details=incomplete_details,
            instructions=request.instructions,
            metadata=request.metadata,
            model=model_name,
            output=output,
            input_messages=input_messages,
            output_messages=output_messages,
            parallel_tool_calls=request.parallel_tool_calls,
            temperature=sampling_params.temperature,
            tool_choice=request.tool_choice,
            tools=request.tools,
            top_p=sampling_params.top_p,
            background=request.background,
            max_output_tokens=sampling_params.max_tokens,
            max_tool_calls=request.max_tool_calls,
            previous_response_id=request.previous_response_id,
            prompt=request.prompt,
            reasoning=request.reasoning,
            service_tier=request.service_tier,
            status=status,
            text=request.text,
            top_logprobs=sampling_params.logprobs,
            truncation=request.truncation,
            user=request.user,
            usage=usage,
        )


2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
class ResponseReasoningPartDoneEvent(OpenAIBaseModel):
    # Streaming event emitted when a reasoning content part is done; the
    # `type` field only accepts "response.reasoning_part.done".
    content_index: int
    """The index of the content part that is done."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that is done."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.done"]
    """The type of the event. Always `response.reasoning_part.done`."""


# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
class ResponseReasoningPartAddedEvent(OpenAIBaseModel):
    # Streaming event emitted when a reasoning content part is added; the
    # `type` field only accepts "response.reasoning_part.added".
    content_index: int
    """The index of the content part that was added."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that was added."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.added"]
    """The type of the event. Always `response.reasoning_part.added`."""


2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
# vLLM Streaming Events
# Note: we override the response type with the vLLM ResponsesResponse type
class ResponseCompletedEvent(OpenAIResponseCompletedEvent):
    # Same event as the OpenAI SDK type, with `response` narrowed to vLLM's
    # ResponsesResponse.
    response: ResponsesResponse  # type: ignore[override]


class ResponseCreatedEvent(OpenAIResponseCreatedEvent):
    # Same event as the OpenAI SDK type, with `response` narrowed to vLLM's
    # ResponsesResponse.
    response: ResponsesResponse  # type: ignore[override]


class ResponseInProgressEvent(OpenAIResponseInProgressEvent):
    # Same event as the OpenAI SDK type, with `response` narrowed to vLLM's
    # ResponsesResponse.
    response: ResponsesResponse  # type: ignore[override]


2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
# Union of every event model that may appear on a streaming Responses API
# call: lifecycle events (created / in_progress / completed), output-item and
# content-part events, reasoning text/part events, and code-interpreter and
# web-search tool-call events.
# NOTE(review): member order is kept as-is; pydantic union validation can be
# order-sensitive, so confirm before reordering.
StreamingResponsesResponse: TypeAlias = (
    ResponseCreatedEvent
    | ResponseInProgressEvent
    | ResponseCompletedEvent
    | ResponseOutputItemAddedEvent
    | ResponseOutputItemDoneEvent
    | ResponseContentPartAddedEvent
    | ResponseContentPartDoneEvent
    | ResponseReasoningTextDeltaEvent
    | ResponseReasoningTextDoneEvent
    | ResponseReasoningPartAddedEvent
    | ResponseReasoningPartDoneEvent
    | ResponseCodeInterpreterCallInProgressEvent
    | ResponseCodeInterpreterCallCodeDeltaEvent
    | ResponseWebSearchCallInProgressEvent
    | ResponseWebSearchCallSearchingEvent
    | ResponseWebSearchCallCompletedEvent
    | ResponseCodeInterpreterCallCodeDoneEvent
    | ResponseCodeInterpreterCallInterpretingEvent
    | ResponseCodeInterpreterCallCompletedEvent
)
2342

2343
2344
2345
# Request-body types that may appear on one line of a batch input file.
# BatchRequestInput.check_type_for_url chooses among them using the line's url.
BatchRequestInputBody: TypeAlias = (
    ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest
)
2346
2347


2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    The `url` field selects which request model `body` is parsed as:
    `/v1/chat/completions`, `/v1/embeddings`, or any URL ending in `/score`
    or `/rerank`. Any other URL falls back to validating `body` against every
    member of `BatchRequestInputBody`.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request; it determines
    # how `body` is validated (see the class docstring).
    url: str

    # The parameters of the request.
    body: BatchRequestInputBody

    @field_validator("body", mode="plain")
    @classmethod
    def check_type_for_url(cls, value: Any, info: ValidationInfo):
        """Validate `body` as the request model implied by `url`.

        Raises:
            pydantic validation errors from the selected request model.
        """
        # `url` is declared before `body`, so it is normally already present
        # in `info.data`. It is missing only when its own validation failed;
        # fall back to the generic union then instead of raising a bare
        # KeyError (which pydantic would not convert to a ValidationError).
        url: str = info.data.get("url", "")
        if url == "/v1/chat/completions":
            return ChatCompletionRequest.model_validate(value)
        if url == "/v1/embeddings":
            # Validated through a TypeAdapter rather than `.model_validate`,
            # presumably because EmbeddingRequest is a type alias/union.
            return TypeAdapter(EmbeddingRequest).validate_python(value)
        if url.endswith("/score"):
            return ScoreRequest.model_validate(value)
        if url.endswith("/rerank"):
            return RerankRequest.model_validate(value)
        return TypeAdapter(BatchRequestInputBody).validate_python(value)
2384

2385

2386
2387
2388
2389
2390
2391
2392
2393
class BatchResponseData(OpenAIBaseModel):
    # The `response` payload of one BatchRequestOutput line.

    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: (
        ChatCompletionResponse
        | EmbeddingResponse
        | ScoreResponse
        | RerankResponse
        | None
    ) = None
2401
2402


2403
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    # Response data for this request, or None.
    response: BatchResponseData | None

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Any | None
2419
2420


2421
class TokenizeCompletionRequest(OpenAIBaseModel):
    # Tokenize request carrying a single raw text prompt.
    model: str | None = None
    prompt: str

    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )
2438
2439
2440


class TokenizeChatRequest(OpenAIBaseModel):
    # Tokenize request carrying chat messages that are rendered through the
    # chat template before tokenization.
    model: str | None = None
    messages: list[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )
    continue_final_message: bool = Field(
        default=False,
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    tools: list[ChatCompletionToolsParam] | None = Field(
        default=None,
        description=("A list of tools the model may call."),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests that explicitly enable both prompt modes.

        Raises:
            ValueError: if both `continue_final_message` and
                `add_generation_prompt` are truthy in the raw input.
        """
        # Runs on the raw input. Non-dict input has nothing to inspect here,
        # so hand it to pydantic unchanged instead of failing on `.get`.
        if not isinstance(data, dict):
            return data
        # Because this runs before defaults are applied, a request that sets
        # only `continue_final_message=True` passes this check while
        # `add_generation_prompt` keeps its default of True.
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

2513

2514
# A tokenize request is either a raw-prompt request or a chat-messages request.
TokenizeRequest: TypeAlias = TokenizeCompletionRequest | TokenizeChatRequest
2515
2516
2517
2518
2519


class TokenizeResponse(OpenAIBaseModel):
    # Result of a tokenize request.
    count: int
    max_model_len: int
    tokens: list[int]
    # Populated only when the request asked for token strings
    # (`return_token_strs`); otherwise None.
    token_strs: list[str] | None = None
2522
2523
2524


class DetokenizeRequest(OpenAIBaseModel):
    # Request to turn a list of token ids back into text.
    model: str | None = None
    tokens: list[int]
2527
2528
2529
2530


class DetokenizeResponse(OpenAIBaseModel):
    # Detokenize result; `prompt` holds the decoded text (presumably of the
    # request's `tokens`).
    prompt: str
2531
2532


2533
2534
class TokenizerInfoResponse(OpenAIBaseModel):
    """
    Response containing tokenizer configuration
    equivalent to tokenizer_config.json
    """

    # extra="allow" lets arbitrary additional config keys pass through
    # alongside the one required field below.
    model_config = ConfigDict(extra="allow")
    tokenizer_class: str


2543
class LoadLoRAAdapterRequest(BaseModel):
    # Name under which the adapter is registered, and where to load it from.
    lora_name: str
    lora_path: str


2548
class UnloadLoRAAdapterRequest(BaseModel):
    # Identifies the adapter to unload by name; the integer id is optional.
    lora_name: str
    lora_int_id: int | None = Field(default=None)
2551
2552
2553


## Protocols for Audio
2554
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
2555
2556
2557
2558


class TranscriptionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranscription

    file: UploadFile
    """
    The audio file object (not file name) to transcribe, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use.
    """

    language: str | None = None
    """The language of the input audio.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy and latency.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    ## TODO (varun) : Support if set to 0, certain thresholds are met !!

    timestamp_granularities: list[Literal["word", "segment"]] = Field(
        alias="timestamp_granularities[]", default=[]
    )
    """The timestamp granularities to populate for this transcription.

    `response_format` must be set `verbose_json` to use timestamp granularities.
    Either or both of these options are supported: `word`, or `segment`. Note:
    There is no additional latency for segment timestamps, but generating word
    timestamps incurs additional latency.
    """

    stream: bool | None = False
    """When set, it will enable output to be streamed in a similar fashion
    as the Chat Completion endpoint.
    """
    # --8<-- [start:transcription-extra-params]
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )
    # --8<-- [end:transcription-extra-params]

    to_language: str | None = None
    """The language of the output audio we transcribe to.

    Please note that this is not currently used by supported models at this
    time, but it is a placeholder for future use, matching translation api.
    """

    # --8<-- [start:transcription-sampling-params]
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    top_p: float | None = None
    """Enables nucleus (top-p) sampling, where tokens are selected from the
    smallest possible set whose cumulative probability exceeds `p`.
    """

    top_k: int | None = None
    """Limits sampling to the `k` most probable tokens at each step."""

    min_p: float | None = None
    """Filters out tokens with a probability lower than `min_p`, ensuring a
    minimum likelihood threshold during sampling.
    """

    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    frequency_penalty: float | None = 0.0
    """The frequency penalty to use for sampling."""

    repetition_penalty: float | None = None
    """The repetition penalty to use for sampling."""

    presence_penalty: float | None = 0.0
    """The presence penalty to use for sampling."""
    # --8<-- [end:transcription-sampling-params]

    # Default sampling parameters for transcription requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Convert this request into engine `SamplingParams`.

        Args:
            default_max_tokens: Used as `max_tokens`; this request type has no
                per-request override for it.
            default_sampling_params: Optional fallback values, keyed by
                parameter name, used for any of temperature / top_p / top_k /
                min_p / repetition_penalty left as None on the request. Keys
                missing here fall back to `_DEFAULT_SAMPLING_PARAMS`.

        Returns:
            `SamplingParams` with DELTA output when `stream` is set (else
            FINAL_ONLY) and `vllm_xargs` forwarded as `extra_args`.
        """
        max_tokens = default_max_tokens
        fallbacks = default_sampling_params or {}

        def pick(name: str, requested: Any) -> Any:
            # Precedence: value on the request, then the caller-supplied
            # fallback, then this class's built-in default.
            if requested is not None:
                return requested
            return fallbacks.get(name, self._DEFAULT_SAMPLING_PARAMS[name])

        # NOTE(review): `temperature` is declared `float` with a 0.0 default,
        # so a validated request never carries None here and its fallbacks
        # are effectively unused — confirm this is intended.
        temperature = pick("temperature", self.temperature)
        top_p = pick("top_p", self.top_p)
        top_k = pick("top_k", self.top_k)
        min_p = pick("min_p", self.min_p)
        repetition_penalty = pick("repetition_penalty", self.repetition_penalty)

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            presence_penalty=self.presence_penalty,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            extra_args=self.vllm_xargs,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_transcription_request(cls, data):
        """Validate the raw form payload before field parsing.

        Raises:
            HTTPException: 422 if `file` was sent as a plain string.
            ValueError: if a stream option is set without `stream=True`.
        """
        # Non-dict input has nothing to inspect here; let pydantic handle it
        # instead of failing on `.get`.
        if not isinstance(data, dict):
            return data

        if isinstance(data.get("file"), str):
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail="Expected 'file' to be a file-like object, not 'str'.",
            )

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data
2739
2740
2741


# Transcription response objects
2742
2743
2744
2745
2746
class TranscriptionUsageAudio(OpenAIBaseModel):
    # Usage reported as audio duration rather than token counts; `type` is
    # always "duration".
    type: Literal["duration"] = "duration"
    # Whole seconds (int); presumably the length of the processed audio —
    # confirm against the producer.
    seconds: int


2747
2748
2749
class TranscriptionResponse(OpenAIBaseModel):
    text: str
    """The transcribed text."""

    # Duration-based usage for this transcription.
    usage: TranscriptionUsageAudio
2751
2752
2753
2754
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
2785
2786
2787
2788
2789
2790
2791
2792
2793
2794
2795
2796
2797
2798
2799
2800
2801


class TranscriptionWord(OpenAIBaseModel):
    # One word of a transcription together with its time span.
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranscriptionSegment(OpenAIBaseModel):
    # One segment of a verbose transcription, with timing and decoding
    # quality statistics.
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranscriptionResponseVerbose(OpenAIBaseModel):
    # Transcription result for `response_format="verbose_json"`; segment and
    # word details are optional.
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The transcribed text."""

    segments: list[TranscriptionSegment] | None = None
    """Segments of the transcribed text and their corresponding details."""

    words: list[TranscriptionWord] | None = None
    """Extracted words and their corresponding timestamps."""
2821
2822
2823
2824


class TranslationResponseStreamChoice(OpenAIBaseModel):
    # One choice within a streamed translation chunk.
    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None
2827
2828
2829
2830
2831
2832
2833
2834


class TranslationStreamResponse(OpenAIBaseModel):
    # One chunk of a streamed translation. `id` and `created` are generated
    # per instance ("trsl-" prefix, current Unix time in whole seconds).
    id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
    object: Literal["translation.chunk"] = "translation.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranslationResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
2836
2837
2838
2839
2840
2841
2842
2843
2844
2845
2846
2847


class TranslationRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranslation

    file: UploadFile
    """
    The audio file object (not file name) to translate, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    # TODO support additional sampling parameters
    # --8<-- [start:translation-sampling-params]
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """
    # --8<-- [end:translation-sampling-params]

    # --8<-- [start:translation-extra-params]
    language: str | None = None
    """The language of the input audio we translate from.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy.
    """

    to_language: str | None = None
    """The language of the input audio we translate to.

    Please note that this is not supported by all models, refer to the specific
    model documentation for more details.
    For instance, Whisper only supports `to_language=en`.
    """

    stream: bool | None = False
    """Custom field not present in the original OpenAI definition. When set,
    it will enable output to be streamed in a similar fashion as the Chat
    Completion endpoint.
    """
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False
    # --8<-- [end:translation-extra-params]

    # Default sampling parameters for translation requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Convert this request into engine `SamplingParams`.

        Args:
            default_max_tokens: Used as `max_tokens`; this request type has no
                per-request override for it.
            default_sampling_params: Optional fallback values keyed by
                parameter name; only "temperature" is consulted, and only
                when the request's temperature is None.

        Returns:
            `SamplingParams` with DELTA output when `stream` is set, else
            FINAL_ONLY.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        # NOTE(review): `temperature` is declared `float` with a 0.0 default,
        # so a validated request never carries None and this fallback is
        # effectively unused — confirm this is intended.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject stream options on a non-streaming request.

        Raises:
            ValueError: if a stream option is set without `stream=True`.
        """
        # Non-dict input has nothing to inspect here; let pydantic handle it
        # instead of failing on `.get`.
        if not isinstance(data, dict):
            return data

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data


# Translation response objects
class TranslationResponse(OpenAIBaseModel):
    # Plain translation result carrying only the translated text.
    text: str
    """The translated text."""


class TranslationWord(OpenAIBaseModel):
    # One word of a translation together with its time span.
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranslationSegment(OpenAIBaseModel):
    # One segment of a verbose translation, with timing and decoding quality
    # statistics (same field set as the transcription segment model).
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranslationResponseVerbose(OpenAIBaseModel):
    # Translation result for `response_format="verbose_json"`; segment and
    # word details are optional.
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The translated text."""

    segments: list[TranslationSegment] | None = None
    """Segments of the translated text and their corresponding details."""

    words: list[TranslationWord] | None = None
    """Extracted words and their corresponding timestamps."""